ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

1

2022-09-04 12:31:51  阅读:176  来源: 互联网

标签: acc macro content train test data


def read_imdb(data_dir, filename):
    data, labels = [], []
    folder_name = os.path.join(data_dir, filename)
    with open(folder_name, 'r',encoding="utf-8") as f:
        json_data = json.loads(f.readline())
        for i in json_data:
            if i["label"]=="neural":
                labels.append(0)
            elif i["label"]=="happy":
                labels.append(1)
            elif i["label"]=="angry":
                labels.append(2)
            elif i["label"]=="sad":
                labels.append(3)
            elif i["label"]=="fear":
                labels.append(4)
            elif i["label"]=="surprise":
                labels.append(5)
            i["content"] = re.sub(r'\/\/\@.*?(\:|\:)', "", i['content']) # 清除@信息
            i['content'] = re.sub(r'\#.*?\#', "", i['content']) # 清除#信息
            i['content'] = re.sub(r'\【.*?\】', "", i['content']) # 清除【标签】
            i['content'] = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', "", i['content'], flags=re.MULTILINE) # 清除链接信息
            data.append(i['content'])
    return data, labels

def data_segmentation(data,labels):
    ssplite = ''
    pdata = []
    for d,l in zip(data,labels):
        content_to_str = ' '.join( jieba.cut(d,cut_all=False)) 
        content_to_str = re.sub("[^\u4e00-\u9fa5^a-z^A-Z^0-9^\s]","", content_to_str) # 去除非中英文字、数字的所有字符
        for i in range(6):
            content_to_str = content_to_str.replace('  ',' ') # 去除多余空格
        content_to_str = content_to_str.strip() # 去除两边空格
        pdata.append([content_to_str.split(' '),l])
        content_to_str += '\r\n'
        ssplite += content_to_str
    return pdata, ssplite

def create_dictionaries(p_model):
    g_dict = Dictionary()
    g_dict.doc2bow(p_model.wv.index_to_key, allow_update=True) # 每一句话进行词频统计
    w2indx = {v: k  for k, v in g_dict.items()}  # 定义word to index词库
    id2vec = {w2indx.get(word): model.wv.__getitem__(word) for word in w2indx.keys()}  # 定义index to vector词库, 词语的embedding
    return w2indx, id2vec

def get_tokenized_imdb(data):
    for word_list, label in data:
        temp = []
        for word in word_list:
            if(word in word_id_dic.keys()):
                temp.append(int(word_id_dic[word]))
            else:
                temp.append(0)
        yield [temp,label]
        
def preprocess_imdb(data):
    max_l = 30  
    def pad(x):
        return x[:max_l] if len(x) > max_l else x + [1] * (max_l - len(x))
    features = torch.tensor([pad(content[0]) for content in data])
    labels = torch.tensor([score for _, score in data])
    return features, labels

class BiRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, num_hiddens,num_layers, **kwargs):
        super(BiRNN, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.encoder = nn.LSTM(embed_size, num_hiddens, num_layers=num_layers,bidirectional=True)
        self.decoder = nn.Linear(4 * num_hiddens, 6)
    def forward(self, inputs):
        embeddings = self.embedding(inputs.T)
        self.encoder.flatten_parameters()
        outputs, _ = self.encoder(embeddings)
        encoding = torch.cat((outputs[0], outputs[-1]), dim=1)
        outs = self.decoder(encoding)
        return outs

def train_epoch(net, data_loader,optimizer, device):
    net.train() #指定当前为训练模式
    l = 0 #记录Loss
    batch_count = 0
    train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
    for x, y in data_loader:
        x = x.to(device)
        y = y.to(device)
        y_hat = net(x) #使用模型计算出预测结果
        optimizer.zero_grad()  #将当前梯度清零
        l = loss(y_hat, y)#计算损失
        l.backward() #进行反向传播
        optimizer.step() #更新权重参数
        train_l_sum += l.cpu().item()
        train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
        n += y.shape[0]
        batch_count += 1
        
    loss_ =  train_l_sum / batch_count  #计算平均loss与准确率
    acc =  train_acc_sum / n
    return loss_, acc

def test_epoch(net, data_loader, device):
    net.eval() #指定当前模式为测试模式
    batch_count = 0
    l = 0
    pred=[]
    real=[]
    test_l_sum, test_acc_sum, n = 0.0, 0.0, 0
    with torch.no_grad(): #指定不进行梯度变化
         for x, y in data_loader:
            x = x.to(device)
            y = y.to(device)
            y_hat = net(x) #使用模型计算出预测结果
            l = loss(y_hat, y)#计算损失
            test_l_sum += l.cpu().item()
            test_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
            
            pred.extend(y_hat.argmax(dim=1).tolist())
            real.extend([int(e) for e in y])
            
    macro_F1=f1_score(real,pred,average='macro')     # f1得分 
    macro_R=recall_score(real,pred,average='macro') # 宏召回率
    macro_P = precision_score(real, pred, average='macro') # 宏精确率
    loss_ =  test_l_sum / batch_count  #计算平均loss与准确率
    acc =  test_acc_sum / n
    return loss_,acc,(macro_F1,macro_R,macro_P)

for epoch in range(num_epochs):
    epochstart = time.perf_counter ()  #每一个epoch的开始时间
    train_loss, train_acc = train_epoch(net.to(device),train_iter,optimizer, device)
    test_loss, test_acc,macro = test_epoch(net.to(device),test_iter, device=device)
    elapsed = (time.perf_counter () - epochstart)  #每一个epoch的结束时间    
    train_loss_list.append(train_loss)
    train_acc_list.append(train_acc)
    test_loss_list.append(test_loss)
    test_acc_list.append(test_acc)
    macro_F1_list.append(macro[0])
    macro_R_list.append(macro[1])
    macro_P_list.append(macro[2])
    time_list.append(elapsed)
    if((epoch+1)%5 == 0):
        print('epoch %d, train_loss %.3f,test_loss %.3f,train_acc %.3f,test_acc %.3f,Time used %.3fs,macro_F1 %.3f,macro_R %.3f,macro_P %.3f'%
              (epoch+1, train_loss,test_loss,train_acc,test_acc,elapsed,macro[0],macro[1],macro[2] ))

  

标签:,acc,macro,content,train,test,data
来源: https://www.cnblogs.com/chrysanthemum/p/16609961.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有