ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

知识图到文本的生成——陆

2021-12-05 17:30:18  阅读:172  来源: 互联网

标签:nn 图到 fields self args 生成 hsz 文本 ds


2021SC@SDUSC

mkiters函数也是dataset类中的一个重要的类函数,我的队友已经在她的博客中详细分析过这个函数,此处不再赘述。

  def mktestset(self, args):
    path = args.path.replace("train",'test')
    fields=self.fields
    ds = data.TabularDataset(path=path, format='tsv',fields=fields)
    ds.fields["rawent"] = data.RawField()
    for x in ds:
      x.rawent = x.ent.split(" ; ")
      x.ent = self.vec_ents(x.ent,self.ENT)
      x.rel = self.mkGraphs(x.rel,len(x.ent[1]))
      if args.sparse:
        x.rel = (self.adjToSparse(x.rel[0]),x.rel[1])
      x.tgt = x.out
      x.out = [y.split("_")[0]+">" if "_" in y else y for y in x.out]
      x.sordertgt = torch.LongTensor([int(y)+3 for y in x.sorder.split(" ")])
      x.sorder = [[int(z) for z in y.strip().split(" ")] for y in x.sorder.split("-1")[:-1]]
    ds.fields["tgt"] = self.TGT
    ds.fields["rawent"] = data.RawField()
    ds.fields["sordertgt"] = data.RawField()
    dat_iter = data.Iterator(ds,1,device=args.device,sort_key=lambda x:len(x.src), train=False, sort=False)
    return dat_iter

mktestset函数是dataset类中一个用来形成测试集的函数,对数据集进行遍历之后,返回一个迭代器。其余的函数都是对数据集进行一些修饰工作,不再一一展开详细分析。

让我们回到最最开始的地方, 继续分析train.py程序。之前我们详细分析了dataset类,而pargs.py由我的队友来着重分析,那么我们就继续看:

m = model(args)

我们开始分析model类。

class model(nn.Module):

首先我们看model类的init函数。

  def __init__(self,args):
    super().__init__()
    self.args = args
    cattimes = 3 if args.title else 2
    self.emb = nn.Embedding(args.ntoks,args.hsz)
    self.lstm = nn.LSTMCell(args.hsz*cattimes,args.hsz)
    self.out = nn.Linear(args.hsz*cattimes,args.tgttoks)
    self.le = list_encode(args)
    self.entout = nn.Linear(args.hsz,1)
    self.switch = nn.Linear(args.hsz*cattimes,1)

这个model类继承了torch.nn,其中的参数都是调用了torch.nn中的函数。cattimes是分类的次数,如果有标题,就设置为3,如果没有标题,就为2。emb为用args.ntoks和args.hsz组成的矩阵(args.ntoks是输出的vocab长度,会在pargs.py的代码分析中详细介绍)。lstm是用hsz和分类次数乘积作为构建LSTM中的一个Cell的输入特征维度,hsz作为构建LSTM中的一个Cell的隐状态的维度,torch.nn中的LSTM和LSTMCell的操作如下图:

 

 out、entout、switch都是调用了nn.Linear()函数,其中的参数都是指维度,对二维变量进行线性变换,如图所示。

    self.attn = MultiHeadAttention(args.hsz,args.hsz,args.hsz,h=4,dropout_p=args.drop)
    self.mattn = MatrixAttn(args.hsz*cattimes,args.hsz)
    self.graph = (args.model in ['graph','gat','gtrans'])
    print(args.model)

MultiHeadAttention()是attention.py中的类,继承Module,这里的操作是返回一个连接后的4*4维度的attn。MatrixAttn()也是attention.py中的类,继承Module,这里是对hsz和分类次数的乘积和hsz作线性变换(如上图所示)。graph则是模型生成的图,然后在终端打印出来。

标签:nn,图到,fields,self,args,生成,hsz,文本,ds
来源: https://blog.csdn.net/qq_50729659/article/details/121542713

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有