目录

•1、网络整体框架

•2 、Patch Merging

•3 、W-MSA

MSA模块计算量

W-MSA模块计算量

•4、 SW-MSA

•5 、Relative Position Bias


1、网络整体框架

 

2 Patch Merging

这里看着挺复杂,其实就相当于先对特征图进行LayerNorm,然后再进行一个卷积核大小为2×2,步距为2的深度可分离卷积。 

3 W-MSA

MSA模块计算量

W-MSA模块计算量

4 SW-MSA

采用W-MSA模块时,只会在每个窗口内进行自注意力计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了SW-MSA模块,即进行偏移的W-MSA。可以理解成窗口从左上角分别向右侧和下方各偏移了M/2 

可以发现通过将窗口进行偏移后,由原来的4个窗口变成9个窗口了。后面又要对每个窗口内部进行MSA,为了避免进行太多的窗口多头自注意力

 为了防止不同窗口之间的信息乱窜,在实际计算中使用的是masked MSA即带蒙板mask的MSA,这样就能够通过设置蒙板来隔绝不同区域的信息。

5 Relative Position Bias

这里描述的是相对位置索引,也就是相对位置关系,并不是相对位置偏置参数。可以根据相对位置索引去获取对应的参数。关键是怎么根据位置索引获取相对位置偏置参数? 

为了方便把二维索引转成一维索引。但如果将行标和列表直接简单相加会出现问题。比如相对位置索引中有(0 , -1)和(-1 , 0) 在二维的相对位置索引中明显是代表不同的位置,但如果简单相加都等于-1那就出问题了。

 

这样每个位置就得到了自己唯一的相对位置索引 

我们可以创建一个可训练的相对位置偏置列表,在列表之找到对应的相对位置偏置。 

Logo

尧米是由西云算力与CSDN联合运营的AI算力和模型开源社区品牌,为基于DaModel智算平台的AI应用企业和泛AI开发者提供技术交流与成果转化平台。

更多推荐