ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

PyTorch笔记--交叉熵损失函数实现

2021-08-10 15:02:16  阅读:537  来源: 互联网

标签:log -- pred torch 笔记 PyTorch softmax logits 函数


交叉熵(cross entropy):用于度量两个概率分布间的差异信息。交叉熵越小,代表这两个分布越接近。

函数表示(这是使用softmax作为激活函数的损失函数表示):

是真实值,是预测值。)

命名说明:

pred=F.softmax(logits),logits是softmax函数的输入,pred代表预测值,是softmax函数的输出。

pred_log=F.log_softmax(logits),pred_log代表对预测值再取对数后的结果。也就是将logits作为log_softmax()函数的输入。

方法一,使用log_softmax()+nll_loss()实现

torch.nn.functional.log_softmax(input)

  对输入使用softmax函数计算,再取对数。

torch.nn.functional.nll_loss(input, target)

  input是经log_softmax()函数处理后的结果,pred_log

  target代表的是真实值。

  有了这两个输入后,该函数对其实现交叉熵损失函数的计算,即上面公式中的L。

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(1, 28)
>>> w = torch.randn(10,28)
>>> logits = x @ w.t()
>>> pred_log = F.log_softmax(logits, dim=1)
>>> pred_log
tensor([[ -0.8779,  -6.7271,  -9.1801,  -6.8515,  -9.6900,  -6.3061,  -3.7304,
          -8.1933, -11.5704,  -0.5873]])
>>> F.nll_loss(pred_log, torch.tensor([3]))
tensor(6.8515)

logits的维度是(1, 10)这里可以理解成是1个输入,最终可能得到10个分类的结果中的一个。pred_log就是

这里的参数target=torch.tensor([3]),我的理解是,他代表真正的分类的值是在第3类(从0编号)。

使用独热编码代表真实值是[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],即这个输入它是属于第三类的。

根据上述公式进行计算,现在我们 都已经知道了。

对其进行点乘操作

 

 

 

 方法二,使用cross_entropy()实现

torch.nn.functional.cross_entropy(input, target)

  这里的input是没有经过处理的logits,这个函数会自动根据logits计算出pred_log

  target是真实值

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(1, 28)
>>> w = torch.randn(10,28)
>>> logits = x @ w.t()
>>> F.cross_entropy(logits, torch.tensor([3]))
tensor(6.8515)

这里我删除了上面使用方法一的代码部分,x和w没有重新随机生成,所以计算结果是一样的。

还在学习过程,做此纪录,如有不对,请指正。

标签:log,--,pred,torch,笔记,PyTorch,softmax,logits,函数
来源: https://www.cnblogs.com/xxmrecord/p/15123626.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有