ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

CAIL2021-阅读理解任务-模型模块

2022-06-09 14:33:46  阅读:159  来源: 互联网

标签:loss None end CAIL2021 outputs 模型 start 模块 logits


代码地址:https://github.com/china-ai-law-challenge/CAIL2021/blob/main/ydlj/baseline/model.py

import torch
from torch.nn import CrossEntropyLoss, BCELoss
from torch import nn


class MultiSpanQA(nn.Module):
    def __init__(self, pretrain_model):
        super(MultiSpanQA, self).__init__()
        self.pretrain_model = pretrain_model
        # represent start logits and end logits respectively
        self.qa_outputs = nn.Linear(pretrain_model.config.hidden_size, 2)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            start_labels=None,  # size: (batch_size, max_seq_length, 1)
            end_labels=None,
    ):
        outputs = self.pretrain_model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        sequence_output = outputs[0]
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        outputs = (start_logits, end_logits,) + outputs[2:]
        if start_labels is not None and end_labels is not None:
            loss_fct = BCELoss(reduction="mean")
            start_loss = loss_fct(torch.sigmoid(start_logits), start_labels)
            end_loss = loss_fct(torch.sigmoid(end_logits), end_labels)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs
        return outputs

模型结构挺简单,就是对每一个token进行二分类,判断是不是答案的起始位置和终止位置。注意这里使用的是BCELoss(),需要先对输出进行sigmoid()处理。

标签:loss,None,end,CAIL2021,outputs,模型,start,模块,logits
来源: https://www.cnblogs.com/xiximayou/p/16359231.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有