ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

【ARXIV2205】EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers

2022-06-18 17:02:13  阅读:183  来源: 互联网

标签:__ dim scale Transformers weight CNNs self EdgeViTs attn


【ARXIV2205】EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers
91/100
发布文章
gaopursuit
未选择文件
请添加图片描述

【ARXIV2205】EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers

论文:https://arxiv.org/abs/2205.03436

代码:https://github.com/1hunters/EdgeViT

基于自注意力机制的视觉Transformer(ViT)在视觉任务上,已经形成和CNN一样强有力的架构,但其计算量和模型大小很大。虽然一些工作通过引入先验信息或级联多阶段结构到ViT中,但在移动设备上仍不够高效。本文研究基于MobileNetV2的轻量化ViT,其通过引入局部-全局-局部(LGL)的bottleneck实现,其结合了注意力机制和CNN的优势。

作者提出了将VIT模型应用于移动端需要考虑的三个问题:

  • 1)推理速度要快。 当前的一些指标如 FLOPs 难以反映模型在移动端的速度,因为内存访问速度、并行性等因素还要综合考虑。
  • 2)模型可以大。 当前手机可以拥有32GB的内存,存储模型并应该做为限制因素。
  • 3)实现的友好性。 SWIN里的 cyclic shift 不便于在移动端实现,因此模型要考虑是否便于在移动端实现。

以上面三个原则为指导,作者提出了 EdgeViTs,设计了一个高效的 局部-全局-局部(LGL) 模块,能够实现更好的准确性和计算效率。

请添加图片描述

模型如上图所示,重点是其中的LGL模块,包括个关键部分:

  • local aggregation: 由卷积和 depth conv 组成
  • global sparse attention: 平均池化后进行注意力计算
  • local propagation: 使用反卷积将缩小的特征图恢复到原来大小。

下面看具体代码,理解起来没有什么难度。

class LocalAgg():
    def __init__(self, dim):
        self.conv1 = Conv2d(dim, dim, 1)
        self.conv2 = Conv2d(im, dim, 3, padding=1, groups=dim)
        self.conv3 = Conv2d(dim, dim, 1)
        self.norm1 = BatchNorm2d(dim)
        self.norm2 = BatchNorm2d(dim)
     forward(self, x):
        x = self.conv1(self.norm1(x))
        x = self.conv2(x)
        x = self.conv3(self.norm2(x))
        return x

class GlobalSparseAttn():
    def __init__(self, dim, sample_rate, scale):
        self.scale = scale
        self.qkv = Linear(dim, dim * 3)
        self.sampler = AvgPool2d(1, stride=sample_rate)
        kernel_size=sr_ratio
        self.LocalProp = ConvTranspose2d(dim, dim, kernel_size, stride=sample_rate, groups=dim
        )
        self.norm = LayerNorm(dim)
        self.proj = Linear(dim, dim)
    def forward(self, x):
        x = self.sampler(x)
        q, k, v = self.qkv(x)
        attn = q @ k * self.scale
        attn = attn.softmax(dim=-1)
        x = attn @ v
        x = self.LocalProp(x)
        x = self.proj(self.norm(x))
        return x

其实,网络整体就是基于CNN的,只不过沿用了 SWIN 的典型架构。实验结果如下表所示。尽管作者说,与MobileViTs相比,EdgeViTs在三种复杂度设置下分别实现了5.4%、2.8%和2.7%的提高,但是我感觉从FLOPs等指标来看,并没有约对的优势。这里是我的个人理解,有不同意见的地方可以随时交流。

请添加图片描述

标签:__,dim,scale,Transformers,weight,CNNs,self,EdgeViTs,attn
来源: https://www.cnblogs.com/gaopursuit/p/16388674.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有