ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

【pytorch】交叉熵损失函数 F.cross_entropy()

2022-02-04 19:34:12  阅读:262  来源: 互联网

标签:loss log cross pytorch entropy softmax soft math out


F.cross_entropy(x,y)

  cross_entropy(x,y)是交叉熵损失函数,一般用于在全连接层之后,做loss的计算。

  其中x是二维张量,是全连接层的输出;y是样本标签值。x[batch_size,type_num];y[batch_size]。

  cross_entropy(x,y)计算结果是一个小数,表示loss的值。

举例说明

x = np.array([[1, 2,3,4,5],#共三3样本,有5个类别
              [1, 2,3,4,5],
              [1, 2,3,4,5]]).astype(np.float32)
y = np.array([1, 1, 0])#这3个样本的标签分别是1,1,0即两个是第2类,一个是第1类
x = torch.from_numpy(x)
y = torch.from_numpy(y).long()
 
soft_out = F.softmax(x,dim=1)#给每个样本的pred向量做指数归一化---softmax

log_soft_out = torch.log(soft_out)#将上面得到的归一化的向量再point-wise取对数

loss = F.nll_loss(log_soft_out, y)#将归一化且取对数后的张量根据标签求和,实际就是计算loss的过程
 
"""
这里的loss计算式根据batch_size归一化后的,即是一个batch的平均单样本的损失,迭代一次模型对一个样本平均损失。
在多个epoch训练时,还会求每个epoch内的总损失,用于衡量epoch之间模型性能的提升。
"""

print(soft_out)
print(log_soft_out)
print(loss)
   
loss = F.cross_entropy(x, y)
print(loss)

输出:

  softmax:
  tensor([[0.0117, 0.0317, 0.0861, 0.2341, 0.6364],
  [0.0117, 0.0317, 0.0861, 0.2341, 0.6364],
  [0.0117, 0.0317, 0.0861, 0.2341, 0.6364]])

  tensor([[-4.4519, -3.4519, -2.4519, -1.4519, -0.4519],
  [-4.4519, -3.4519, -2.4519, -1.4519, -0.4519],
  [-4.4519, -3.4519, -2.4519, -1.4519, -0.4519]])

  tensor(3.7852)

结果分析

F.softmax(x,dim=1):一行和为1 sum([0.0117, 0.0317, 0.0861, 0.2341, 0.6364])=1

softmax函数公式

torch.log(soft_out):对softmax的结果进行取对数

a =pow(math.e,1)/(pow(math.e,1)+pow(math.e,2)+pow(math.e,3)+pow(math.e,4)+pow(math.e,5)) # 0.011656230956039609近似0.0117

print(math.log(0.011656230956039609)) # -4.4519

F.nll_loss(log_soft_out, y):对取对数的结果,根据y的值,(y值是索引),找到对应的值,黄色部分,各自取相反数再相加,求平均

(3.4519+3.4519+4.4519)/3 = 3.7852

所以:

cross_entropy函数:softmax->log->nll_loss

特别注意⭐⭐

  全连接层的输出形状为[batch_size,type_num]。含义是第i个样本为第j类的概率。

标签:loss,log,cross,pytorch,entropy,softmax,soft,math,out
来源: https://blog.csdn.net/qq_43592352/article/details/122784389

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有