ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

pytorch API

2022-06-01 15:05:08  阅读:190  来源: 互联网

标签:nn pred torch label np pytorch API print


  1. pytorch--多标签分类损失函数
import torch
import numpy as np

pred = np.array([[-0.4089, -1.2471, 0.5907],
                [-0.4897, -0.8267, -0.7349],
                [0.5241, -0.1246, -0.4751]])
label = np.array([[0, 1, 1],
                  [0, 0, 1],
                  [1, 0, 1]])

pred = torch.from_numpy(pred).float()
label = torch.from_numpy(label).float()

## 通过BCEWithLogitsLoss直接计算输入值(pick)
crition1 = torch.nn.BCEWithLogitsLoss()
loss1 = crition1(pred, label)
print(loss1)

crition2 = torch.nn.MultiLabelSoftMarginLoss()
loss2 = crition2(pred, label)
print(loss2)

##  通过BCELoss计算sigmoid处理后的值
crition3 = torch.nn.BCELoss()
loss3 = crition3(torch.sigmoid(pred), label)
print(loss3)

标签:nn,pred,torch,label,np,pytorch,API,print
来源: https://www.cnblogs.com/mercurysun/p/16334222.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有