ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

pytorch F.cross_entropy(x,y)理解

2020-05-26 20:03:12  阅读:553  来源: 互联网

标签:loss log pow cross pytorch entropy soft math out


F.cross_entropy(x,y)

 1 x = np.array([[1, 2,3,4,5],
 2              [1, 2,3,4,5],
 3              [1, 2,3,4,5]]).astype(np.float32)
 4 y = np.array([1, 1, 0])
 5 x = torch.from_numpy(x)
 6 y = torch.from_numpy(y).long()
 7 
 8 soft_out = F.softmax(x,dim=1)
 9 log_soft_out = torch.log(soft_out)
10 loss = F.nll_loss(log_soft_out, y)
11 print(soft_out)
12 print(log_soft_out)
13 print(loss)
14   
15 loss = F.cross_entropy(x, y)
16 print(loss)

结果:

softmax:

tensor([[0.0117, 0.0317, 0.0861, 0.2341, 0.6364],
[0.0117, 0.0317, 0.0861, 0.2341, 0.6364],
[0.0117, 0.0317, 0.0861, 0.2341, 0.6364]])


tensor([[-4.4519, -3.4519, -2.4519, -1.4519, -0.4519],
[-4.4519, -3.4519, -2.4519, -1.4519, -0.4519],
[-4.4519, -3.4519, -2.4519, -1.4519, -0.4519]])


tensor(3.7852)
tensor(3.7852)

结果分析:

F.softmax(x,dim=1):一行和为1 sum([0.0117, 0.0317, 0.0861, 0.2341, 0.6364])=1
softmax函数公式

 

torch.log(soft_out):对softmax的结果进行取对数
a =pow(math.e,1)/(pow(math.e,1)+pow(math.e,2)+pow(math.e,3)+pow(math.e,4)+pow(math.e,5)) # 0.011656230956039609近似0.0117
print(math.log(0.011656230956039609)) # -4.4519
F.nll_loss(log_soft_out, y):对取对数的结果,根据y的值,(y值是索引),找到对应的值,黄色部分,各自取相反数再相加,求平均
(3.4519+3.4519+4.4519)/3 = 3.7852
所以:
cross_entropy函数:softmax->log->nll_loss
参考链接:
https://blog.csdn.net/qq_22210253/article/details/85229988
https://blog.csdn.net/wuliBob/article/details/104119616
https://blog.csdn.net/weixin_38314865/article/details/104487587?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.nonecase&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.nonecase


 

 

 

标签:loss,log,pow,cross,pytorch,entropy,soft,math,out
来源: https://www.cnblogs.com/shuangcao/p/12968336.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有