ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

PyTorch 深度学习实践 第4讲:反向传播

2022-08-02 17:01:39  阅读:182  来源: 互联网

标签:loss 梯度 grad PyTorch 反向 计算 深度 forward data


反向传播(Back Propagation):

视频教程

1.代码说明:
  • forward 计算loss
  • backward 反向计算梯度
  • 由sgd再更新W权重
import torch

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = torch.tensor([1.0])#选择权重,w=【1.0】
w.requires_grad = True#提醒w需要计算梯度

def forward(x):
    return x * w
#w是tensor,二者相乘,x自动类型转换成tensor, x*w输出构建计算图(w需要计算梯度,所以输出结果也需要计算梯度)


# 损失函数:动态构建计算图
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2
构建的计算图:

image

print("predict (before training)",4, forward(4).item())

for epoch in range(100):#进行100轮
    for x, y in zip(x_data,y_data):#每次从x_data,y_data 中组一个样本(x,y)
        l = loss(x,y) #forward:计算样本损失  
        l.backward()#backward:从loss 开始,将链路上所有需要计算梯度的值自动计算出
        print('\tgrad:',x,y,w.grad.item())#将计算的梯度存到grad中,item将梯度中数值取出来变成标量,防止产生计算图
        w.data = w.data - 0.01 * w.grad.data
        #将权重的数值进行修改,注意,此处就是纯计算,用.data,不是张量求梯度
        w.grad.data.zero_()# 权重梯度的数据全部清0,保证w改变,得到相应梯度
    print('progress:',epoch,l.item())#输出轮数与 损失值
print("predict(after training)",4, forward(4).item())
注意:
  1. w是Tensor(张量类型),Tensor包含data,grad,他们都是Tensor,grad初始为None,调用l.backward()方法后w.grad为Tensor,故更新w.data时需使用w.grad.data。(此处理解:张量用来构建计算图,更改W的值得用.data
  2. 在构建计算图中,与W相关的Tensor都需要求梯度
  3. l.backward()会把计算图中所有需要梯度(grad)的地方都会求出来,然后把梯度都存在对应的待求的参数中,最终计算图被释放。
  4. 更新权重,记得要设置grad为0,

2.计算说明:

  • forward:给了x与权重的量,在f模块中计算局部梯度,向前层层得出z的值,计算loss值。
  • backward:计算loss与输出量z的偏导,一步步向前计算l与x,w的导数,一步步传播梯度。
    image
    image

标签:loss,梯度,grad,PyTorch,反向,计算,深度,forward,data
来源: https://www.cnblogs.com/Ling-22/p/16544305.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有