搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(八)

作者丨科技猛兽
审稿丨邓富城
编辑丨极市平台

极市导读

本文为详细解读Vision Transformer的第八篇,本文主要介绍了两个用以加深Transformer模型的工作:DeepViT、CaiT。>>加入极市CV技术交流群,走在计算机视觉的最前沿

考虑到每篇文章字数的限制,每一篇文章将按照目录的编排包含三个小节,而且这个系列会随着Vision Transformer的发展而长期更新。
搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(一)
搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(二)
搞懂Vision Transformer 原理和代码,看这篇技术综述就够了(三)
搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(四)
搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(五)
搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(六)
搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(七)

本文目录

17 DeepViT: 解决注意力坍塌以构建深层ViT
(来自 新加坡国立大学, 字节跳动AI Lab(美国))
17.1 DeepViT原理分析

18 CaiT:Going deeper with Image Transformers
(来自 Facebook)
18.1 CaiT原理分析
18.2 CaiT代码解读

Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,现在比较火热的 Bert 也是基于 Transformer。Transformer 模型使用了 Self-Attention 机制,不采用 RNN 的顺序结构,使得模型可以并行化训练,而且能够拥有全局信息。

本文介绍的两个工作工作的初衷是为了加深Transformer模型。在模型加深的过程中可以同时保证稳定训练和精度提升。深度学习的基本思想之一就是让模型更深一点,这在CNN中已经很好的体现。例如2012年图像分类冠军AlexNet就是一个更深的模型。著名的ResNet本质上也是一种如何训练更深的CNN模型的方法。

但是反观基于Transformer的模型ViT,DeiT,它们不论是base,small还是tiny尺寸的模型,大小都是只有12个layers。如果直接把depth加深的话,模型的性能会迅速地饱和。那么一个自然而然的问题就是:能不能像CNN一样,采取一些什么手段来加深Transformer模型呢?

17 DeepViT: 解决注意力坍塌以构建深层ViT

论文名称:17 DeepViT: Towards Deeper Vision Transformer

论文地址:

DeepViT: Towards Deeper Vision Transformer

https://arxiv.org/abs/2103.11886

  • 17.1 Deep-ViT原理分析:

Motivation

这个工作的初衷是为了加深Transformer模型。为什么这么做的原因论文中没有明显的交代,但是深度学习的基本思想之一就是让模型更深一点,这在CNN中已经很好的体现。例如2012年图像分类冠军AlexNet就是一个更深的模型。著名的ResNet本质上也是一种如何训练更深的CNN模型的方法。

但是反观基于Transformer的模型ViT,DeiT,它们不论是base,small还是tiny尺寸的模型,大小都是只有12个layers。如果直接把depth加深的话,模型的性能会迅速地饱和,甚至32层ViT的性能还不如24层的模型,如下图1所示。那么一个自然而然的问题就是:能不能像CNN一样,采取一些什么手段来加深Transformer模型呢?

图1:不同深度下的ViT在ImageNet的性能

为了解答这个问题,首先要弄清楚的是:为啥简单地加深ViT性能会出现饱和呢?

原因是作者发现:模型的注意力图 (attention maps),在深层时会非常相似,甚至是变成一样的。这表明,随着ViT的深入,self-attention 在生成不同注意力以捕获rich representation方面变得不那么有效了。作者称之为注意力坍塌 (attention collapse)。

具体来讲,作者衡量不同layer直接attention map的相似性的方法是计算余弦相似度:

式中  代表第  个layer,  代表第  个head,  代表token。

是个  维度的张量,  是个  维的向量,代表第  个head的第  个输入的token对  个output token的注意力大小。所以  指的是第  层的第  个head的第  个输入的这个值,而  指的是第  层的第  个head的第  个输入的这个值。二者做cosine值以后记得到的  为  层和  层的第  个head的第  个输入的这个值的相似度。 当  越大时,代表相似度越高;当  越小时,代表相似度越低。当  时,代表token  在  层和  层所起的作用完全一致。

根据这样的定义,作者使用 ImageNet-1k 训练了一个32层的ViT模型。下图2的意义论文里讲得很含糊,我推测是每个点代表一种ratio,这个ratio是指这个层的token与周围  个层的对应token的相似性。该图表明了随着深度的增加,attention map和  个附近的block的attention map越来越相似。

下图3代表的是这32个层里面,随着层数的加深,具有相似的attention map的block的数量(红色线),以及跨层的相似度(黑色线)。比如最右边的红色点表示在到了第32层的时候,相似的block就达到了16个,而且跨层相似度达到了近0.5左右。

图3:随着ViT模型深度的增加,相似块与块总数的比率也会增加。

下图4代表的是同1个block内部不同head的的注意力图相似性,发现块内不同头之间的相似性都低于30%,它们呈现出足够的多样性。

图4:同1个block内部不同head的的注意力图相似性

Attention collapse这个问题究竟是如何影响ViT模型的性能的呢?

如图5所示的黑色线代表最后一个特征图与前面每个特征图的相似度;红色线代表相邻的层中具有相似的attention map的比例。从大概20层之后,这个相似度变得非常高,使得学习到的特征图和attention map都差不多了。这一观察表明,attention collapse这个问题是ViT不可扩展问题的原因。

图5:最后一个特征图/attention map与前面每个特征图/attention map的相似度

作者又对注意力坍塌这个问题做了进一步分析:如下图6所示的shared blocks这个参数代表把模型最后的这么多层的  和  设置为一模一样的值以后得到的Acc。比如说第2行的这个11,就代表把一个24层的ViT模型的最后11层的  和  设置为一模一样的值,使得最后11层的attention map也是一样的。发现掉点并不严重,这表明最后深层的attention map其实本来就很接近,即:注意力坍塌这个问题确实是存在的,使得简单地加深ViT模型是无效的。

图6:共享attention map的结果

为了解决这个问题,作者提出了2种解决办法:

第1种是增加self-attention模块的embedding dimension,这样做的目的是为了增加每个token的表达能力。这样一来,生成的注意力图可以更加多样化,每个块的attention map之间的相似性可以减少。作者设计了4种不同的embedding dimension,分别是256,384,512,768。如下图7所示,随着embedding dimension的增长,相似的block的数量在下降,同时模型的Acc在上升。换句话讲,注意力坍塌的问题得以缓解,这证明了本文的论点,即:注意力坍塌问题是影响ViT模型scaling的原因之一。

尽管增加embedding dimension是有效的,但增加embedding dimension也会显著增加计算成本,带来的性能改进往往会减少,且需要更大的数据量来训练,增加了过拟合的风险。

图7:不同的embedding dimension以及相似度和Acc

鉴于此,作者提出了第2种方法。第2种是通过一个改进的模块,叫做Re-attention。 其想法来自于作者发现同一个层的不同head之间的相似度比较小,所以说来自同一自注意力层的不同head关注输入token 的不同方面。那如果可以把不同的head的信息结合起来,利用它们再生出一个attention map,是不是能够避免注意力坍塌问题呢?

所以,Re-attention以一种可学习的方式交换来自不同attention heads的信息再生注意力图。具体来讲,作者把得到的attention map 经过一步Linear Transformation  ,得到再生的attention map,写成公式就是:

注意这个可学习的  是作用在head这个维度上的。Norm取得是BatchNorm而不是传统的LayerNorm。

这样做维度好处是充分利用了不同head之间的信息,因为这些信息的相似度低,所以可以有效地缓解注意力坍塌的问题,而且很容易实现,也不用像增加embedding diemnsion一样增加那么多的参数量和计算量。

图8:使用了Re-attention以后的ViT模型

下图9为原始ViT模型和DeepViT的Feature map的跨层余弦相似性,可以看出,用Re-attention取代原始的Self-attention可以显著降低特征Feature map的相似性。

图9:原始ViT模型和DeepViT的Feature map的跨层余弦相似性

如下图10所示为普通的self-attention与Re-attention的attention map的可视化结果。

普通的self-attention的attention map的特点是:只学习到一些局部的patch之间的relationship,而且在网络的深层时attention map很接近。

Re-attention的attention map的特点是:学习到更大范围内的patch之间的relationship,而且在网络的深层时attention map有差别。

Re-attention使得网络深层的attention map保持了多样性,并与相邻块具有更小的相似性。

图10:attention map的可视化结果

Experiments:

优化器: AdamW

初始学习率: 5e-4

学习率衰减策略: cosine

batch size: 256

模型配置如下图11所示:

图11:模型配置

作者进行的一个很重要的是研究是分析Re-attention的作用。就是想看看加了Re-attention以后,模型会不会随着层数的加深而达到饱和,会不会随着层数的加深而相似度逐渐升高?

结果如下图12所示,传统的ViT模型当层数分别加深为16,24,32层时,相似的块的数量分别是5,11,16,也就是说不相似的只有前11层,且Acc不会继续增长。但是使用了Re-attention以后,没有出现相似的块,换句话说,Re-attention使得更深的模型的每个block都是有贡献的。这使得随着模型的加深,Acc也会逐渐增高,说明Re-attention可以解决注意力坍塌的问题。

图12:Re-attention使得更深的模型的每个block都是有贡献的

作者也尝试了一种解决注意力坍塌问题的办法,就是通过给self-attention加一个temperature  :

的作用是什么呢?

的减小可以让attention map的分布变得更加的sharp,使得attention map的数值拉开,在实际实现的时候作者使得  的值随着层数的加深而逐渐衰减,或者干脆把  设置为可学习的参数。

作者也尝试了一种解决注意力坍塌问题的办法,就是随机drop掉attention map中的一些值。因为对于不同的block来说,会随机drop掉不同的值,故attention maps之间的相似性可以得到减少。

这2种做法的结果如下图13中的绿线所示和图14所示,无论是加上temperature参数,还是采用drop attention的操作,相似的block的数量仍然很多,以及不同block之间的相似性仍然很大。所以,加上这个不断衰减的  参数对缓解注意力坍塌问题来讲只起到非常小的作用。

采用drop attention确实可以在一定程度上减弱不同blocks之间的相似性,但是attention map之间相似性的减弱来源于mask位置的不同,这对feature map影响不大。也就是说,深层的,不同层的feature map还是非常相似的。

图13:给self-attention加上temperature参数或者采用drop attention后的不同block之间的相似度
图14:给self-attention加上temperature参数或者采用drop attention后的相似的block的数量

最后作者可视化了最后一层的Transformation matrix  的样子,如图15所示,  为每个head分配了不同的权重,使得再生过程改变了不同的head的权重。比如说第31层是一个  ,第32层是一个不同的  ,凭借这一点来降低block之间的attention map的相似性。

图15:Transformation matrix

与SOTA的对比:

作者提出了两种尺度的模型,DeepViT-S 和 DeepViT-L 层数分别是16层和32层。那么加深ViT模型会使得参数量急剧上升,为了保持和ViT的参数量大致相当,作者把embedded dimension做小了一点,ViT-B 和ViT-L的embedded dimension分别是768和1024,而DeepViT-S 和 DeepViT-L 的 embedded dimension分别是396和408。结果如下:

相比于ViT-B 和ViT-L,DeepViT-S 和 DeepViT-L 以更小的参数量获得了更好的性能。

图16:与SOTA的对比

总结:

本文提出了一种加深ViT模型的方法。所基于的论点是作者发现随着模型的加深,以一个32层模型为例,深层的 (17-32层)的不同层之间的attention map的相似度很高,然而不同的head之间的相似度却很低。基于此,作者通过一个转移矩阵,把不同的head的信息给结合起来了。借此,使得深层的block里面,不同的层训练出来的这个转移矩阵是不同的,达到异化attention map的目的。所以,DeepViT能够有效缓解注意力坍塌,使得ViT模型的性能可以随着层数的加深而增加。

18 CaiT:Going deeper with Image Transformers

论文名称:Going deeper with Image Transformers

论文地址:

Going deeper with Image Transformersa

https://arxiv.org/abs/2103.17239

  • 18.1 CaiT原理分析:

18.1.1 优秀前作DeiT

CaiTDeiT一样都是来自Facebook的同一团队,一作都是Hugo Touvron。

关于DeiT论文,论文和代码的详细解读请读者们参考下文:

Training data-efficient image transformers & distillation through attention

https://arxiv.org/abs/2012.12877

科技猛兽:Vision Transformer 超详细解读 (原理分析+代码解读) (三)

https://zhuanlan.zhihu.com/p/349315675

在这里简单重复下DeiTViT的不同之处:

DeiT可以通过一组很优秀的超参数使得在不改变任何架构的前提下实现了ViT的涨点。

图17:DeiT超参数的设置

当然,DeiT也可以改变架构,DeiT在训练时可加入distillation token,如下图18所示,如果你觉得这部分我讲的有点简单了,请参考原文或上面博客链接。

图18:使用蒸馏策略的DeiT结构

ViT的输出是一个softmax,它代表着预测结果属于各个类别的概率的分布。ViT的做法是直接将这个softmax与GT label取  。

而在DeiT中,除了这个 以外,还要加上一个蒸馏损失:

蒸馏分两种,一种是软蒸馏(soft distillation),另一种是硬蒸馏(hard distillation)。先讲讲软蒸馏,如下式所示,右半部分,和分别是student model和teacher model的输出,  表示  散度,表示softmax函数, 和  是超参数。

硬蒸馏如下式所示,  表示交叉熵。

简而言之,蒸馏的含义就是:学生网络的输出  与真实标签取  ,接着如果是硬蒸馏,就再与教师网络的标签取  。如果是软蒸馏,就再与教师网络的softmax输出结果取  。

值得注意的是,硬标签也可以通过标签平滑技术 (Label smoothing) 转换成软标签,其中真值对应的标签被认为具有  的概率,剩余的  由剩余的类别共享。  是一个超参数,这里取0.1。

训练示意图如下图19所示:

图19:DeiT训练流程

在训练过程中,class token来自ViT,保持不变。只是右侧加上了distillation token。

左侧class token的输出通过一个MLP head得到predicted label,与真实标签GT label取  。

右侧distillation token的输出通过一个MLP head得到predicted label,与教师模型的输出Teacher predicted label取软蒸馏损失 $\mathcal{L}\mathrm{global} 或硬蒸馏损失\mathcal{L}\mathrm{global}^\mathrm{hard Distill} $ 。

以上就是DeiT相对于ViT的变化,可以概括为2点:

  • 一个distillation token配上蒸馏损失。
  • 一组优秀的超参数。

18.1.2 正式介绍CaiT

Motivation

这个工作的Motivation也很直接,就是为了去优化Deeper Transformer。

本文所基于的核心论点是:网络架构 (architecture) 和优化 (architecture) 是相互影响,相互作用的。

换句话讲,一个网络优化的结果怎样,与其架构是密切相关的。架构能影响优化结果,而优化方式也可以决定架构。

作者先以著名的ResNet为例,其表达式可以写作:

其中,  一般是残差,即Identity操作,而  代表这个block的具体操作。

ResNet的设计就是上述核心论点的最好证明。kaiming大佬也在论文中这样讲到:

Residual networks do not offer a better representational power. They achieve better performance because they are easier to train.

也就是说,ResNet不具备更优秀的表征能力,只是使网络变得更易训练罢了。再次证明了网络架构 (architecture) 和优化 (architecture) 是相互影响,相互作用的。

再回到vision transformer上面来,其每个block的表达式可以写作:

现在的问题就是:如何对这个架构进行更好地归一化 (Normalize),训练权重 (Weigh) 和初始化 (Initialize)?

CaiT给出的答案,相对DeiT来讲,也可以概括为2点:

  • LayerScale: 使Deep Vision Transformer易于收敛,并能提高精度。
  • class-attention layers:高效的处理class token的方式。

和DeiT一样,依然是2个改进,很简单但是相当有用。

最重要的是,ResNet,DeiT,CaiT的这种“小改进”是真的能够给vision transformer带来性能提升的,非常solid的工作,代码开源完整,follower很多。而且这些改进你都可以很快地在你自己的工程中复现并且收获满意的结果。

具体方法1:使Deep Vision Transformer易于收敛,并能提高精度的LayerScale

LayerScale这个操作的目的也很直接,就是为了使得vision transformer在训练时更稳定。作者首先对比了如下图20所示的4种不同的transformer blocks的正则化策略,看看它们哪个有助于提升优化的效果。

图20:4种不同的transformer blocks的正则化策略

图20 (a) 中的操作:是ViT/DeiT中使用的pre-norm结构,即先进行Layer Normalization,再进行Self-attention或者FFN。其结果直接与这个block的输入相加。

图20 (b) 中的操作:引入了一个可学习的参数  ,作用在residual block的输出,并取消了Layer Normalization的操作。

ReZero方法把  初始化为0,Fixup方法把  初始化为1,但是作者在实验中发现这些方法都不会使训练收敛。

图20 (c) 中的操作是把(a) 和 (b)的操作结合在一起了,发现(c)的操作是可行的。把  初始化为一个比较小的数值,这种办法在网络较深时也可以收敛。

图20 (d) 中的操作就是本文提出的LayerScale操作了。

具体做法是保持Layer Normalization,并对Self-attention或者FFN的输出乘以一个对角矩阵。目的是给不同的channel乘以不同的  值。比如在channel这个维度,一共有  个channel,为它们分别乘以  以后再与原来的  相加。

式中的  和  都是可学习的参数。在18层之前,它们初始化为0.1。若网络更深,则在24层之前初始化为  。若网络更深,则在之后更深的网络中初始化为  。这样做的目的是希望使得每个block在一开始的时候更接近Identity mapping,在训练的过程中逐渐地学习到模块自身的function。作者通过实验证明,以这种方式训练网络更容易,而且这样做比 (c)的方式更加灵活。

LayerScale这种方式不会影响网络的表达能力,因为它也可以看成是融入了之前的Self-attention或者FFN模块中。

具体方法2:class-attention layers:高效的处理class token的方式

下面介绍详细的CaiT的结构:

下图21中最左侧是ViT的结构。作者先指出了ViT的一个问题,即:ViT在优化参数时,其实是要class token及其后续层的对应的class embedding同时做到2件事情,即:

1. 引导attention过程,帮助得到attention map。

2. 这些token还要输入到classifier中,完成分类任务。

作者认为class embedding要同时做到这2件事其实有点矛盾,参数会朝着矛盾的方向优化。所以作者想能不能把class token往后放,如下图middle所示。为啥要往后放呢?主要是想让class token不参与前面的Self-attention,使得参数的优化方向尽量不产生矛盾。

如果把这种思想再进一步,就得到了CaiT的结构,如图21右侧图所示。

  • self-attention与ViT一致,只是不再使用class token了。
  • 在网络的最后2层,self-attention变成了一个叫class-attention的东西。它依然是attention的功能,只是只更新最左侧的class embedding,而不更新其它的  个patch embedding (被freeze住了)。换句话说,class-attention这个模块只有patch embedding传递给class embedding的信息,而没有class embedding传递给patch embedding的信息。而且只有class embedding会传给FFN。
图21:CaiT的具体结构

仔细观察下这个class-attention层 (CA layer)的特点:patch embedding不更新,只更新class embedding,我们可以思考下作者为啥要这样设计,下面是我个人的理解:

首先你想完成分类任务,就得靠class embedding。class embedding后面一定要接MLP head完成分类任务。那么class embedding就得获得这个图片的全部信息。

而图片的信息都在patch embedding里面,所以我们需要一个layer来从patch embedding里面抽取信息给class embedding。所以这个CA层必须具有这个功能,即:

从patch embedding里面抽取信息给class embedding。

然后,按照上文的分析,为了使得参数的优化方向尽量不产生矛盾,class embedding就只干好分类的事就行,尽量别参与attention过程。你的任务就是分类的,不要参与其他活动比如attention过程,这样使得class embedding的角色更专一。所以:

class embedding的信息不能给patch embedding。

所以根据提炼出的这的2个要求,CA层的设计就容易理解了。

下面是把CA层用数学公式表达出来。

假设为了有  个head和  个patch,embedding dimension为  。

,计算CA层的  :

式中  。

式中  。它加在了  上面得到这个CA层的输出。

实际上CA layer使用2层就够了,前面依旧是12个正常的Transformer Block。

计算复杂度也不高,因为普通的self-attention计算  ,而CA计算  ,复杂度由  变成了  。

Experiments

实验的超参数设计方面,作者在DeiT的基础上做了一些微小的变动。为了加速训练并优化显存占用,作者使用了Fairscale library,fp16精度。

作者观测到依旧使用DeiT的超参数在训练深层的ViT模型时会出现不收敛的现象,很难训练,尤其是在仅仅使用ImageNet数据集的时候。

实验1:使训练更稳定的方法

  • 调整不同深度的drop-rate:

就是使得每一层的drop-rate随着层数的加深而线性增加。但是作者实验证明这种方法并没有太大的作用。最多训到18层深的ViT,若想把模型变得更深,则需要加大dropout-rate,如下图22左侧所示。实际实现时的不同层的dropout-rate设置为:

  • Normalization可以使深层ViT收敛:

这个地方1的实验还是在展开图20的 (b) 和 (c),(b)代表使用了Fixup 和 T-Fixup等技术,(c)代表作者又加上了Layer Normalization。作者发现只使用Fixup 和 T-Fixup等技术,模型无法收敛,只有加上了Layer Normalization以后才会收敛,甚至会提升模型性能,如下图22中间侧所示。

结果显示,Fixup 和 T-Fixup等技术在模型比较浅时与LayerScale效果相当,但是它们实现更复杂,因为需要为不同类型的layer设置不同的初始化方法。

  • LayerScale的作用:

作者训练了一个36层的模型,并计算了每一层的  ,这个值越大,代表残差块的作用越大,代表模型距离Identity越远,其结果如下图23所示。图23中的上面2个图代表Self-attention,图23中的下面2个图代表FFN。橘色的图代表不使用LayerScale的结果,而蓝色的图代表使用了LayerScale的结果。实验发现使用了LayerScale之后这个比例变得更加uniform了,这个好处也是模型更易训练的原因之一。

图22:使模型更易收敛的方法对比
图23:LayerScale的作用

实验2:Class-attention layers的作用

如下图24所示为CA层在不同设置下的性能对比。结论是使用12层Self-attention和2层CA可以得到最佳的性能。

图24:CA层在不同设置下的性能对比

不同尺寸的CaiT模型的架构:

如下图25所示为不同尺寸的CaiT模型的架构。左侧的24,36代表模型的层数。XXS,XS,S,M代表模型的大小。  为embedding dimension,为的是保持每个head的dimension为48不变 (注意在DeiT中每个head的dimension为64,CaiT稍微小一点)。

图25:不同尺寸的CaiT模型的架构

CaiT改变了一些超参数 (相比于DeiT):

Model CaiT DeiT
learning rate 0.0001 0.0005
warm up epochs 5 5
weight decay 0.05 0.05
learning rate decay cosine cosine

不同CaiT模型的dropout rate和 LayerScale的初始化参数  :

图26:不同CaiT模型的dropout rate和 LayerScale的初始化参数

实验3:与SOTA模型的对比

作者对比了近期出现的一些ViT模型的变体,如TNT,DeiT,T2T-ViT,以及一些基于CNN的模型EfficientNet等。CaiT-M36取得了SOTA的性能。它在ImageNet上取得了86.3%的Acc1 performance,相比于DeiT-B的85.2%来说取得了惊人的提升。

图27:与SOTA模型的对比

实验4:小数据集上的迁移学习效果

如下图28所示,是在小数据集上finetune的结果。对于CIFAR-100 和 CIFAR-10 数据集需要把learning rate减小100倍,其他数据集减小10倍。

图28:小数据集上的迁移学习效果

实验5:从DeiT渐变到CaiT

作为与前作DeiT的呼应,作者将DeiT-S模型逐渐变化到CaiT-36,看看每一步的变化对模型性能的影响,如下图29所示。其中最重要的方法LayerScale和CA层如灰色框所示,它们涨点明显。另外,在分辨率为384的数据集上进行finetune也是涨点的重要方法之一。

图29:从DeiT渐变到CaiT,每一步的变化对模型性能的影响

实验6:head数量对优化的影响

如下图30所示,虽然不同head数量的FLOP的数量大致相同,但在增加此head数量时,计算更加分散,在典型硬件上,这将导致有效吞吐量降低。实际选择8,以求提供了一个很好的精度和速度之间的trade-off。

图30:head数量对优化的影响

class-attention层可视化

如下图31所示为作者可视化模型仅有的2个CA层的attention map,分别摘取了4个head的结果。每张图的2行分别为第1个CA层的attention map和第2层的CA层的attention map。

之所以要可视化CA层的attention map,是因为CA层是CaiT中唯一使得class embedding与patch embedding相互作用的层。观察发现:

  • 第1个class-attention层关注图片中的object多一点,不同的head关注相同会不同的部位。
  • 第2个class-attention层关注图片的背景信息或全局信息。
图31:class-attention层可视化

总结

本文提出了2种优化ViT的方式:

  • LayerScale: 使Deep Vision Transformer易于收敛,并能提高精度。
  • class-attention layers:高效的处理class token的方式。

LayerScale解决了训练Deep vision transformer的问题,通过它使得深层ViT更易于训练。class-attention layers使得class embedding的职能更加专门化,使得参数的优化方向尽量不产生矛盾,就是为了完成分类任务,使得训练目标更明确。

  • 18.2 CaiT代码解读:

CaiTDeiT一样都是来自Facebook的同一团队,一作都是Hugo Touvron大佬。所以它们的代码共用了1个github仓,位置在:

https://github.com/facebookresearch/deit

关于DeiT论文,论文和代码的详细解读请读者们参考下文:

Training data-efficient image transformers & distillation through attention

https://arxiv.org/abs/2012.12877

科技猛兽:Vision Transformer 超详细解读 (原理分析+代码解读) (三)

https://zhuanlan.zhihu.com/p/349315675

此处就只介绍CaiT的代码了:

打开上面的git仓,可以看到代码结构为:

图32:DeiT代码结构

发现相比于DeiT,多了一个cait_models.py,这个应该就是CaiT的模型结构了。

入口文件:

通常看一个github要先看代码的入口文件 (或者叫启动文件),不论你是在 云上GPU集群训练 (多节点) 的,还是在本地小服务器上 (单节点) 训练的,都需要入口文件。很明显这里面的main.pyrun_with_submitit.py就是入口文件了。

比如说在本地的小服务器 (比如说高校实验室的8卡服务器) 上训练,你可以使用下面的指令自己吃4张卡开始单机多卡训练DeiT-small模型:

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model deit_small_patch16_224 --batch-size 256 --data-path /path/to/imagenet --output_dir /path/to/save

再比如说在企业的云服务器 (比如说华为的ModelArts) 上训练,你可以使用下面的指令自己吃16张卡 (2个node,8 gpus per node,默认值) 开始在 384 resolution images 上多机多卡fine-tune DeiT-base模型:

pip install submitit
python run_with_submitit.py --model deit_base_patch16_384 --batch-size 32 --finetune https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth --input-size 384 --use_volta32 --nodes 2 --lr 5e-6 --weight-decay 1e-8 --epochs 30 --min-lr 5e-6

这个run_with_submitit.py也是会调用 main.py里面的main函数

这个入口文件main.py大量使用了timm库的实现。timm库的链接和解析如下:

https://github.com/rwightman/pytorch-image-models/tree/master/timm/models

科技猛兽:视觉Transformer优秀开源工作:timm库vision transformer代码解读

https://zhuanlan.zhihu.com/p/350837279

from timm.data import Mixupfrom timm.models import create_modelfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropyfrom timm.scheduler import create_schedulerfrom timm.optim import create_optimizerfrom timm.utils import NativeScaler, get_state_dict, ModelEma


losses.py定义蒸馏损失的这个类,直接调用即可:

from losses import DistillationLoss


datasets.py定义数据集的类,使用时直接:

from datasets import build_dataset
dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)dataset_val, _ = build_dataset(is_train=False, args=args)


samplers.py定义采样器:

from samplers import RASampler


utils.py定义杂七杂八的工具函数:

import utils


main.py首先传入相关的命令行参数:

def get_args_parser(): parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) parser.add_argument('--batch-size', default=64, type=int) parser.add_argument('--epochs', default=300, type=int)
# Model parameters parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', help='Name of model to train') parser.add_argument('--input-size', default=224, type=int, help='images input size')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', help='Dropout rate (default: 0.)') parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', help='Drop path rate (default: 0.1)')
parser.add_argument('--model-ema', action='store_true') parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') parser.set_defaults(model_ema=True) parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
... ...


再定义dataloader:

if True: # args.distributed: num_tasks = utils.get_world_size() global_rank = utils.get_rank() if args.repeated_aug: sampler_train = RASampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True ) else: sampler_train = torch.utils.data.DistributedSampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True ) if args.dist_eval: if len(dataset_val) % num_tasks != 0: print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'equal num of samples per-process.') sampler_val = torch.utils.data.DistributedSampler( dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val)


定义dataloader:

用上之前创建的dataset_train,dataset_val,sampler_train,sampler_val和相关命令行参数。

data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, )
data_loader_val = torch.utils.data.DataLoader( dataset_val, sampler=sampler_val, batch_size=int(1.5 * args.batch_size), num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False )


创建模型:

create_model函数来自timm.models文件。

print(f"Creating model: {args.model}") model = create_model( args.model, pretrained=False, num_classes=args.nb_classes, drop_rate=args.drop, drop_path_rate=args.drop_path, drop_block_rate=None, )


如果是在预训练模型的基础上fine-tune:

if args.finetune: if args.finetune.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( args.finetune, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.finetune, map_location='cpu')
checkpoint_model = checkpoint['model'] state_dict = model.state_dict() for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: print(f"Removing key {k} from pretrained checkpoint") del checkpoint_model[k]
# interpolate position embedding pos_embed_checkpoint = checkpoint_model['pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.patch_embed.num_patches
# num_extra_tokens: 1,2 num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding # H,W orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)

# height (== width) for the new position embedding # new H,W new_size = int(num_patches ** 0.5)
# extra_tokens: 1,2 # class_token and dist_token are kept unchanged extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# extra_tokens: N # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
# (b, d, H, W) pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate( pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
# (b, new H * new W, d) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
# (b, new N + 1, d) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) checkpoint_model['pos_embed'] = new_pos_embed
model.load_state_dict(checkpoint_model, strict=False)
model.to(device)


如果有蒸馏,导入教师模型:

teacher_model = None if args.distillation_type != 'none': assert args.teacher_path, 'need to specify teacher-path when using distillation' print(f"Creating teacher model: {args.teacher_model}") teacher_model = create_model( args.teacher_model, pretrained=False, num_classes=args.nb_classes, global_pool='avg', ) if args.teacher_path.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( args.teacher_path, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.teacher_path, map_location='cpu') teacher_model.load_state_dict(checkpoint['model']) teacher_model.to(device) teacher_model.eval()


如果有训练好的checkpoint,则都直接加载进去,包括模型,优化器,学习率变化策略,初始epoch数等:

checkpoint['model'] → model_without_ddp
checkpoint['optimizer'] → optimizer
checkpoint['lr_scheduler'] → lr_scheduler
checkpoint['epoch'] → start_epoch
checkpoint['model_ema'] → model_ema
checkpoint['scaler'] → loss_scaler

if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if args.model_ema: utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler'])


测试模型:

if args.eval: test_stats = evaluate(data_loader_val, model, device) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") return


训练代码:

print(f"Start training for {args.epochs} epochs") start_time = time.time() max_accuracy = 0.0 for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, set_training_mode=args.finetune == '' # keep in eval mode during finetuning )
lr_scheduler.step(epoch) if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] for checkpoint_path in checkpoint_paths: utils.save_on_master({ 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'model_ema': get_state_dict(model_ema), 'scaler': loss_scaler.state_dict(), 'args': args, }, checkpoint_path)
test_stats = evaluate(data_loader_val, model, device) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%')
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters}
if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))


接下来就到CaiT模型的代码cait_model.py了:

导入必要的库,大多来自timm library:

import torchimport torch.nn as nnfrom functools import partial
from timm.models.vision_transformer import Mlp, PatchEmbed , _cfgfrom timm.models.registry import register_modelfrom timm.models.layers import trunc_normal_


所有模型的名称:

__all__ = [ 'cait_M48', 'cait_M36', 'cait_M4', 'cait_S36', 'cait_S24','cait_S24_224', 'cait_XS24','cait_XXS24','cait_XXS24_224', 'cait_XXS36','cait_XXS36_224']


class attention模块:

图21以及式(18.5, 18.6, 18.7)中的CA层:

其中各个变量的维度已在代码中标注出来了。
注意forward函数里面的输入是2个:x, x_cls, 只有其他的patch embedding的信息传给x_cls,而不会反向传。

class Class_Attention(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to do CA def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias) self.k = nn.Linear(dim, dim, bias=qkv_bias) self.v = nn.Linear(dim, dim, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x ): B, N, C = x.shape
# q: (b, num_heads, 1, dim/num_heads) q = self.q(x[:,0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# k: (b, num_heads, N, dim/num_heads) k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
q = q * self.scale
# v: (b, num_heads, N, dim/num_heads) v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# attn: (b, num_heads, 1, N) attn = (q @ k.transpose(-2, -1)) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn)
# # v: (b, num_heads, 1, dim/num_heads) --> (b, 1, dim) x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C) x_cls = self.proj(x_cls) x_cls = self.proj_drop(x_cls) # x_cls: (b, 1, dim) return x_cls


LayerScale模块 (使用上面描述的class attention):

如图20(d)所示。

矩阵的定义就是使用了nn.Parameter操作,那么attention/FFN做完以后先要乘上这个  对角矩阵,再与输入求和。

class LayerScale_Block_CA(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add CA and LayerScale def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block = Class_Attention, Mlp_block=Mlp,init_values=1e-4): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention_block( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
def forward(self, x, x_cls): u = torch.cat((x_cls,x),dim=1) x_cls = x_cls + self.drop_path(self.gamma_1 * self.attn(self.norm1(u))) x_cls = x_cls + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x_cls))) return x_cls


正常的attention操作和对应的LayerScale (使用正常的attention):

class Attention_talking_head(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf) def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_l = nn.Linear(num_heads, num_heads) self.proj_w = nn.Linear(num_heads, num_heads) self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0] * self.scale , qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) attn = self.proj_l(attn.permute(0,2,3,1)).permute(0,3,1,2) attn = attn.softmax(dim=-1) attn = self.proj_w(attn.permute(0,2,3,1)).permute(0,3,1,2) attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class LayerScale_Block(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add layerScale def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,Attention_block = Attention_talking_head, Mlp_block=Mlp,init_values=1e-4): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention_block( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
def forward(self, x): x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x


正式定义CaiT模型:

depth 代表正常block的数量。
depth_token_only 代表CA层的数量。
Attention_block 代表正常self-attention
Attention_block_token_only 代表CA层的attention
block_layers 代表使用LayerScale和正常attention的block。
block_layers_token 代表使用LayerScale和CA层的block。

class cait_models(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to adapt to our cait models def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, global_pool=None, block_layers = LayerScale_Block, block_layers_token = LayerScale_Block_CA, Patch_layer=PatchEmbed,act_layer=nn.GELU, Attention_block = Attention_talking_head,Mlp_block=Mlp, init_scale=1e-4, Attention_block_token_only=Class_Attention, Mlp_block_token_only= Mlp, depth_token_only=2, mlp_ratio_clstk = 4.0): super().__init__()


定义2种不同的block:

self.blocks = nn.ModuleList([ block_layers( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer,Attention_block=Attention_block,Mlp_block=Mlp_block,init_values=init_scale) for i in range(depth)])
self.blocks_token_only = nn.ModuleList([ block_layers_token( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer, act_layer=act_layer,Attention_block=Attention_block_token_only, Mlp_block=Mlp_block_token_only,init_values=init_scale) for i in range(depth_token_only)])


初始化参数:

def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0)


前向传播:

def forward_features(self, x): B = x.shape[0] x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1) x = x + self.pos_embed x = self.pos_drop(x)
for i , blk in enumerate(self.blocks): x = blk(x) for i , blk in enumerate(self.blocks_token_only): cls_tokens = blk(x,cls_tokens)
x = torch.cat((cls_tokens, x), dim=1) x = self.norm(x) return x[:, 0]
def forward(self, x): x = self.forward_features(x) x = self.head(x)
return x


通过register_model注册不同大小的模型,这里以cait_XXS24_224为例:

@register_modeldef cait_XXS24_224(pretrained=False, **kwargs): model = cait_models( img_size= 224,patch_size=16, embed_dim=192, depth=24, num_heads=4, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), init_scale=1e-5, depth_token_only=2,**kwargs) model.default_cfg = _cfg() if pretrained: checkpoint = torch.hub.load_state_dict_from_url( url="https://dl.fbaipublicfiles.com/deit/XXS24_224.pth", map_location="cpu", check_hash=True ) checkpoint_no_module = {} for k in model.state_dict().keys(): checkpoint_no_module[k] = checkpoint["model"]['module.'+k] model.load_state_dict(checkpoint_no_module) return model


register_model这个函数来自timm库model文件夹下的registry.py文件,它的作用是:
@ 指装饰器
@register_model代表注册器,注册这个新定义的模型。
存储到_model_entrypoints这个字典中,比如

_model_entrypoints[cait_XXS24_224] = cait_models( img_size= 224,patch_size=16, embed_dim=192, depth=24, num_heads=4, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), init_scale=1e-5, depth_token_only=2,**kwargs)


然后在factory.pycreate_model函数中的下面这几行真正创建模型,你以后想创建的任何模型都会使用create_model这个函数,这里说清楚了为什么要用它:

if is_model(model_name): create_fn = model_entrypoint(model_name) else: raise RuntimeError('Unknown model (%s)' % model_name)
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): model = create_fn(pretrained=pretrained, **kwargs)


比如刚才在main.py里面用了create_model创建模型,如下面代码所示。而create_model就来自factory.py:

model = create_model( args.model, pretrained=False, num_classes=args.nb_classes, drop_rate=args.drop, drop_path_rate=args.drop_path, drop_block_rate=None, )


总结

本文介绍的两个工作工作的初衷是为了加深Transformer模型。在模型加深的过程中可以同时保证稳定训练和精度提升。深度学习的基本思想之一就是让模型更深一点,这在CNN中已经很好的体现。

DeepViT所基于的论点是作者发现随着模型的加深,以一个32层模型为例,深层的 (17-32层)的不同层之间的attention map的相似度很高,然而不同的head之间的相似度却很低。基于此,作者通过一个转移矩阵,把不同的head的信息给结合起来了。借此,使得深层的block里面,不同的层训练出来的这个转移矩阵是不同的,达到异化attention map的目的。所以,DeepViT能够有效缓解注意力坍塌,使得ViT模型的性能可以随着层数的加深而增加。

CaiT提出了2种优化ViT的方式:

  • LayerScale: 使Deep Vision Transformer易于收敛,并能提高精度。
  • class-attention layers:高效的处理class token的方式。

LayerScale解决了训练Deep vision transformer的问题,通过它使得深层ViT更易于训练。class-attention layers使得class embedding的职能更加专门化,使得参数的优化方向尽量不产生矛盾,就是为了完成分类任务,使得训练目标更明确。

(0)

相关推荐