ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

简洁易懂的PyTorch版ResNet50复现代码

2021-02-09 14:58:53  阅读:370  来源: 互联网

标签:__ ResNet50 nn self channels stride PyTorch 易懂 out


ResNet50网络架构

ResNet50的网络解构相对简单,没有涉及到复杂的组件,大概50行代码就能复现。但我每次想用它的时候都会忘点东西,比如Bottleneck的结构如何实现,ResNet50的几个阶段各包含几个块等等,想着得写一篇文章记录下,免得以后又重复搬砖。ResNet50的网络结构如下,论文中网络的输入为 3x224x224,先经过步长为 2 填充为 3 的 7x7 卷积 + BN + ReLU和步长为 2 填充为 1 的 3x3 最大池化,接着经过4个阶段,每个阶段包含的 Bottleneck 卷积块分别为3、4、6、3,最后经过步长为 1 填充为 0 的 7x7 均值池化、Flatten 和输入为 2048 维,输出为 1000 维的全连接层,经过 Softmax 操作后得到网络的分类概率预测。
ResNet50结构

Bottleneck卷积块

Bottleneck卷积块是ResNet50核心的部分,ResNet50的每个阶段由若干Bottleneck组成,其中第一个Bottleneck的输入与输出通道数不一致,需要使用 1x1 卷积 + BN 映射 Shortcut 后相加,其余的Bottleneck则是直接将 Shortcut 进行相加。包含与不包含1x1映射的Bottleneck结构分别如下所示:
Bottleneck结构

PyTorch复现代码

# ResNet50.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, groups=1, activation=True):
        super(Conv, self).__init__()
        padding = kernel_size // 2 if padding is None else padding
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              padding, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True) if activation else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

class Bottleneck(nn.Module):
    def __init__(self, in_channels, out_channels, down_sample=False, groups=1):
        super(Bottleneck, self).__init__()
        stride = 2 if down_sample else 1
        mid_channels = out_channels // 4
        self.shortcut = Conv(in_channels, out_channels, kernel_size=1, stride=stride, activation=False) \
            if in_channels != out_channels else nn.Identity()
        self.conv = nn.Sequential(*[
            Conv(in_channels, mid_channels, kernel_size=1, stride=1),
            Conv(mid_channels, mid_channels, kernel_size=3, stride=stride, groups=groups),
            Conv(mid_channels, out_channels, kernel_size=1, stride=1, activation=False)
        ])

    def forward(self, x):
        y = self.conv(x) + self.shortcut(x)
        return F.relu(y, inplace=True)

class ResNet50(nn.Module):
    def __init__(self, num_classes):
        super(ResNet50, self).__init__()
        self.stem = nn.Sequential(*[
            Conv(3, 64, kernel_size=7, stride=2),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        ])
        self.stages = nn.Sequential(*[
            self._make_stage(64, 256, down_sample=False, num_blocks=3),
            self._make_stage(256, 512, down_sample=True, num_blocks=4),
            self._make_stage(512, 1024, down_sample=True, num_blocks=6),
            self._make_stage(1024, 2048, down_sample=True, num_blocks=3),
        ])
        self.head = nn.Sequential(*[
            nn.AvgPool2d(kernel_size=7, stride=1, padding=0),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(2048, num_classes)
        ])

    def _make_stage(self, in_channels, out_channels, down_sample, num_blocks):
        layers = [Bottleneck(in_channels, out_channels, down_sample=down_sample)]
        for _ in range(1, num_blocks):
            layers.append(Bottleneck(out_channels, out_channels, down_sample=False))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.head(self.stages(self.stem(x)))

if __name__ == "__main__":
    inputs = torch.rand((8, 3, 224, 224)).cuda()
    model = ResNet50(num_classes=1000).cuda().train()
    outputs = model(inputs)
    print(outputs.shape)

标签:__,ResNet50,nn,self,channels,stride,PyTorch,易懂,out
来源: https://blog.csdn.net/hlld__/article/details/113755368

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有