ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

PyTorch深度学习入门笔记(十二)神经网络-非线性激活

2022-01-27 21:03:42  阅读:166  来源: 互联网

标签:入门 nn import self torch 神经网络 PyTorch output input


课程学习笔记,课程链接
学习笔记同步发布在我的个人网站上,欢迎来访查看。

一、非线性激活常用函数介绍

非线性激活的目的是为了给我们的神经网络引入一些非线性的特质。
依然是打开官方文档
在这里插入图片描述
比较常用的函数是 nn.ReLu:

1.1 ReLU

在这里插入图片描述

对应的函数图是:
在这里插入图片描述
参数:inplace=True时,会修改 input 为非线性激活后的结果;inplace=False时,则不会修改 input ,input仍然为原值。
在这里插入图片描述
示例代码:

import torch
from torch import nn
from torch.nn import ReLU

input = torch.tensor([[1, -0.5],
                      [-1, 3]])
output = torch.reshape(input, (-1, 1, 2, 2))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.relu1 = ReLU()

    def forward(self, input):
        output = self.relu1(input)
        return output

net1 = Net()
output = net1(input)
print(output)

输出:
在这里插入图片描述
可以看到输入为 输出为
[[1, -0.5], [[1, 0],
[-1, 3]], [0, 3]]
输入传给了 ReLU函数,进行了截断,-0.5和-1均小于0,所以对应的输出结果为0。

1.2 Sigmoid

在这里插入图片描述
函数图:
在这里插入图片描述
示例代码:

import torch
import torchvision
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("./dataset", train=False, download=True,
                                       transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.Sigm = Sigmoid()

    def forward(self, input):
        output = self.Sigm(input)
        return output

net1 = Net()

writer = SummaryWriter("logs")
step = 0
for data in dataloader:
    imgs, targets = data
    writer.add_images("input", imgs, global_step=step)
    output = net1(imgs)
    writer.add_images("output", output, step)
    step += 1

writer.close()

tensorboard 查看输出结果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

可以看到非线性激活主要目的就是给网络增加非线性特征,以便训练出符合要求的泛化模型。

标签:入门,nn,import,self,torch,神经网络,PyTorch,output,input
来源: https://blog.csdn.net/qq_44447544/article/details/122722638

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有