标签:loss obj self YOLO num fg 置信度 Focal
将YOLOv3及以上的网络中的BCE loss更改为Focal loss
loss函数分为三部分,位置损失、置信度损失、类别损失,此处只需要将置信度损失更换为Focal loss,具体原理请仔细理解置信度损失的含义。
YOLOX链接:https://link.zhihu.com/?target=https%3A//github.com/Megvii-BaseDetection/YOLOX
1 找到置信度预测损失计算位置loss_obj,并进行替换(位置在386-405行左右)
loss_iou:定位损失;loss_obj:置信度预测损失;loss_cls:预测损失
loss_iou = (
self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)
).sum() / num_fg
#loss_obj = (
# self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)
#).sum() / num_fg
loss_obj = (
self.focal_loss(obj_preds.sigmoid().view(-1, 1), obj_targets)
).sum() / num_fg
loss_cls = (
self.bcewithlog_loss(
cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets
)
).sum() / num_fg
2 创建focal_loss方法,放到def get_l1_target(…)之前即可,代码如下:
def focal_loss(self, pred, gt):
pos_inds = gt.eq(1).float()
neg_inds = gt.eq(0).float()
pos_loss = torch.log(pred+1e-5) * torch.pow(1 - pred, 2) * pos_inds * 0.75
neg_loss = torch.log(1 - pred+1e-5) * torch.pow(pred, 2) * neg_inds * 0.25
loss = -(pos_loss + neg_loss)
return loss
标签:loss,obj,self,YOLO,num,fg,置信度,Focal 来源: https://blog.csdn.net/weixin_43850171/article/details/123140328
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。