ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

神经网络——损失函数、反向传播与优化器

2022-01-27 09:02:09  阅读:178  来源: 互联网

标签:loss 函数 kernel torch lh 神经网络 反向 result size


loss

loss越小越好

  • 计算实际输出和目标之间的差距
  • 为我们更新输出提供一定的依据(反向传播)

调用torch中已有损失函数:

result_loss = loss(output, target)

backward

反向传播:计算每一个参数的梯度

result_loss.backward()

优化器

注意:需要清除之前的梯度值

实例

import torch.optim
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

# 准备数据集
dataset = torchvision.datasets.CIFAR10("../pytorch_learn/dataset2", train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=1)

#  创建一个神经网络
class lh(nn.Module):
    def __init__(self):
        super(lh, self).__init__()

        # sequential用法
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

loss = nn.CrossEntropyLoss()
# 网络
lh = lh()
# 优化器
optim = torch.optim.SGD(lh.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        img, target = data
        # 网络输出
        output = lh(img)
        # loss
        result_loss = loss(output, target)
        # 清除之前的梯度
        optim.zero_grad()
        # 梯度
        result_loss.backward()
        # 优化
        optim.step()
        running_loss = running_loss + result_loss
    print(running_loss)

标签:loss,函数,kernel,torch,lh,神经网络,反向,result,size
来源: https://blog.csdn.net/qq_42806080/article/details/122709788

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有