torch.nn.Embedding(num_embeddings, embedding_dim): converting text into word vectors and completing a text sentiment classification task



1. Processing the dataset

import torch
import os
import re
from torch.utils.data import Dataset, DataLoader


dataset_path = r'C:\Users\ci21615\Downloads\aclImdb_v1\aclImdb'


def tokenize(text):
    """
    Tokenize and clean the raw review text.
    :param text:
    :return:
    """
    # characters to strip; '\\\\' is a properly escaped literal backslash for the regex below
    fileters = ['!', '"', '#', '$', '%', '&', '\(', '\)', '\*', '\+', ',', '-', '\.', '/', ':', ';', '<', '=', '>', '\?', '@'
        , '\[', '\\\\', '\]', '^', '_', '`', '\{', '\|', '\}', '~', '\t', '\n', '\x97', '\x96', '”', '“', ]
    text = re.sub("<.*?>", " ", text, flags=re.S)  # drop HTML tags such as <br />
    text = re.sub("|".join(fileters), " ", text, flags=re.S)  # replace punctuation with spaces
    return [i.strip() for i in text.split()]


class ImdbDataset(Dataset):
    """
    Prepare the IMDB dataset.
    """
    def __init__(self, mode):
        super(ImdbDataset, self).__init__()
        if mode == 'train':
            text_path = [os.path.join(dataset_path, i) for i in ['train/neg', 'train/pos']]
        else:
            text_path = [os.path.join(dataset_path, i) for i in ['test/neg', 'test/pos']]
        self.total_file_path_list = []
        for i in text_path:
            self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])

    def __getitem__(self, item):
        cur_path = self.total_file_path_list[item]
        cur_filename = os.path.basename(cur_path)
        # the rating is encoded in the file name, e.g. 1234_8.txt -> label 7
        label = int(cur_filename.split('_')[-1].split('.')[0]) - 1
        text = tokenize(open(cur_path, encoding='utf-8').read().strip())
        return label, text

    def __len__(self):
        return len(self.total_file_path_list)


if __name__ == '__main__':
    imdb_dataset = ImdbDataset('train')
    print(imdb_dataset[0])

After processing, each dataset item is a (label, token_list) tuple: printing imdb_dataset[0] shows the numeric label followed by the review's tokens.

2. A custom collate_fn for the DataLoader

def collate_fn(batch):
    """
    batch is a list of tuples, each tuple being the result of __getitem__ in the dataset
    :param batch:
    :return:
    """
    batch = list(zip(*batch))
    labels = torch.tensor(batch[0], dtype=torch.int32)
    texts = batch[1]
    del batch
    return labels, texts


dataset = ImdbDataset('train')
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)


if __name__ == '__main__':
    for index, (label, text) in enumerate(dataloader):
        print(index)
        print(label)
        print(text)
        break

Current result: each batch yields the batch index, a tensor of two labels, and a tuple of two token lists. (A custom collate_fn is needed here because the default one cannot stack variable-length token lists into a tensor.)

3. Text serialization

Each word first needs to be assigned an initial integer id, and that id is then converted into a vector.
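As a quick illustration of that second step (a minimal sketch, not part of the original post; the sizes are arbitrary), torch.nn.Embedding is essentially a lookup table from integer ids to dense vectors:

import torch
import torch.nn as nn

# arbitrary illustration values: a 10-word vocabulary, 4-dimensional vectors
embedding = nn.Embedding(num_embeddings=10, embedding_dim=4)
ids = torch.tensor([[2, 5, 6]], dtype=torch.long)  # one sequence of 3 word ids
vectors = embedding(ids)                           # shape: (1, 3, 4), one 4-d vector per id
print(vectors.shape)

The Word2Sequence class below takes care of the first step, assigning each word its integer id.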

import numpy as np


class Word2Sequence():
    """
    Text serialization.
    Approach:
    1. Tokenize every sentence
    2. Store the words in a dictionary, filter them by frequency and count occurrences
    3. Implement a method that converts text into a sequence of numbers
    4. Implement a method that converts a sequence of numbers back into text
    """
    UNK_TAG = 'UNK'
    PAD_TAG = 'PAD'
    UNK = 0
    PAD = 1

    def __init__(self):
        self.dict = {
            self.UNK_TAG: self.UNK,
            self.PAD_TAG: self.PAD
        }
        self.fited = False

    def to_index(self, word):
        """
        Word to number.
        :param word:
        :return:
        """
        assert self.fited == True
        return self.dict.get(word, self.UNK)

    def to_word(self, index):
        """
        Number to word.
        :param index:
        :return:
        """
        assert self.fited
        if index in self.inversed_dict:
            return self.inversed_dict[index]
        return self.UNK_TAG

    def __len__(self):
        return len(self.dict)

    def fit(self, sentences, min_count=1, max_count=None, max_feature=None):
        """
        :param sentences: [[word1, word2, word3], [word1, word3, wordn..], ...]
        :param min_count: minimum number of occurrences
        :param max_count: maximum number of occurrences
        :param max_feature: maximum vocabulary size
        :return:
        """
        count = {}
        # count how often each word occurs
        for sentence in sentences:
            for a in sentence:
                if a not in count:
                    count[a] = 0
                count[a] += 1
        # filter by frequency, i.e. drop words that occur too rarely (or too often)
        if min_count is not None:
            count = {k: v for k, v in count.items() if v >= min_count}
        if max_count is not None:
            count = {k: v for k, v in count.items() if v <= max_count}
        # limit the total vocabulary size;
        # the id of each word is simply the size of dict at the moment it is added
        if isinstance(max_feature, int):
            count = sorted(list(count.items()), key=lambda x: x[1])
            if max_feature is not None and len(count) > max_feature:
                count = count[-int(max_feature):]
            for w, _ in count:
                self.dict[w] = len(self.dict)
        else:
            for w in sorted(count.keys()):
                self.dict[w] = len(self.dict)
        self.fited = True
        self.inversed_dict = dict(zip(self.dict.values(), self.dict.keys()))

    def transform(self, sentence, max_len=None):
        """
        Convert a sentence into an array (vector) of ids.
        :param sentence:
        :param max_len:
        :return:
        """
        assert self.fited
        if max_len is not None:
            r = [self.PAD] * max_len
        else:
            r = [self.PAD] * len(sentence)
        if max_len is not None and len(sentence) > max_len:
            sentence = sentence[:max_len]
        for index, word in enumerate(sentence):
            r[index] = self.to_index(word)
        return np.array(r, dtype=np.int64)

    def inverse_transform(self, indices):
        """
        Convert an array of ids back into words.
        :param indices: [1, 2, 3, ...]
        :return: [word1, word2, ...]
        """
        sentence = []
        for i in indices:
            word = self.to_word(i)
            sentence.append(word)
        return sentence


if __name__ == '__main__':
    w2s = Word2Sequence()
    w2s.fit([
        ['这', '是', '什', '么'],
        ['那', '是', '神', '么']
    ])
    print(w2s.dict)
    print(w2s.fited)
    print(w2s.transform(['神', '么', '这']))
    print(w2s.transform(['神么这'], max_len=10))

Result:
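Based on the code above (fit adds the remaining words after UNK = 0 and PAD = 1 in sorted order), the printed output should look roughly like the following; this is reconstructed from the code rather than copied from the original screenshot:

{'UNK': 0, 'PAD': 1, '么': 2, '什': 3, '是': 4, '神': 5, '这': 6, '那': 7}
True
[5 2 6]
[0 1 1 1 1 1 1 1 1 1]

The last call maps the single unknown token '神么这' to UNK (0) and fills the remaining positions up to max_len=10 with PAD (1).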

4. Build a vocabulary from the IMDB data, mapping each word to a number

import os
import pickle

from tqdm import tqdm

# Word2Sequence, tokenize and dataset_path come from the previous sections


# Process the IMDB data and save the resulting vocabulary
def fit_save_word_sequence():
    """
    Build the vocabulary from the training set.
    :return:
    """
    ws = Word2Sequence()
    train_path = [os.path.join(dataset_path, i) for i in ['train/neg', 'train/pos']]
    total_file_path_list = []
    for i in train_path:
        total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])
    for cur_path in tqdm(total_file_path_list, desc='fitting'):
        sentence = open(cur_path, encoding='utf-8').read().strip()
        res = tokenize(sentence)
        ws.fit([res])
    # save the Word2Sequence instance
    print(ws.dict)
    print(len(ws))
    pickle.dump(ws, open('./model/ws.pkl', 'wb'))


if __name__ == '__main__':
    fit_save_word_sequence()

5. Convert each text into a fixed-length id sequence, with a configurable max_len

import pickle

import torch
from torch.utils.data import DataLoader

# ImdbDataset and Word2Sequence come from the previous sections


def get_dataloader(mode='train'):
    """
    Build a DataLoader whose texts have already been converted to id sequences.
    :param mode:
    :return:
    """
    # load the vocabulary
    ws = pickle.load(open('./model/ws.pkl', 'rb'))
    print(len(ws))

    # custom collate_fn
    def collate_fn(batch):
        """
        batch is a list of tuples, each tuple being the result of __getitem__ in the dataset
        :param batch:
        :return:
        """
        max_len = 500
        batch = list(zip(*batch))
        labels = torch.tensor(batch[0], dtype=torch.int32)
        texts = batch[1]
        # the original length of each text, capped at max_len
        lengths = [len(i) if len(i) < max_len else max_len for i in texts]
        # every text is converted into a length-max_len (here 500) sequence of ids
        temp = [ws.transform(i, max_len) for i in texts]
        texts = torch.tensor(temp)

        del batch
        return labels, texts, lengths

    dataset = ImdbDataset(mode)
    dataloader = DataLoader(dataset=dataset, batch_size=20, shuffle=True, collate_fn=collate_fn)
    return dataloader


if __name__ == '__main__':
    for index, (label, texts, length) in enumerate(get_dataloader()):
        print(index)
        print(label)
        print(texts)
        print(length)

Error encountered: constructing the embedding layer for these id sequences raises an index-out-of-range error. Recall the signature:

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)
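The failure happens when an id in the input is greater than or equal to num_embeddings; a minimal sketch that reproduces this kind of error (the sizes here are illustrative, not the post's actual values):

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=5, embedding_dim=3)
ids = torch.tensor([1, 2, 7])  # 7 >= num_embeddings, so the lookup is out of range
emb(ids)                       # typically raises IndexError: index out of range in self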

Put simply, num_embeddings (the number of entries the embedding can hold) was not large enough. Why not? In principle, when we embed words, the vocabulary maps all of our words (or characters) to the ids 0, 1, ..., n, so num_embeddings = n looks sufficient at first glance. But once padding is taken into account (PAD conventionally takes id 0 by default), the word ids are shifted to 1, 2, ..., n + 1, and num_embeddings must be n + 1 before every id is covered. Adding 1 to the vocabulary size is therefore enough, and the error no longer occurs.
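Putting the fix together with the vocabulary saved in section 4, a hedged sketch of how the embedding layer could be sized (an assumption about the eventual model, not code from the original post; embedding_dim=100 is an arbitrary choice):

import pickle

import torch.nn as nn

# Word2Sequence (needed for unpickling) and get_dataloader come from the sections above
ws = pickle.load(open('./model/ws.pkl', 'rb'))
# len(ws) counts every vocabulary entry including UNK and PAD; the extra +1 follows
# the fix described above so that every id produced by ws.transform stays in range
embedding = nn.Embedding(num_embeddings=len(ws) + 1, embedding_dim=100)

if __name__ == '__main__':
    for labels, texts, lengths in get_dataloader('train'):
        embedded = embedding(texts)  # shape: (20, 500, 100) = (batch_size, max_len, embedding_dim)
        print(embedded.shape)
        break

From here the embedded tensor can be fed into whatever classifier head is used for the sentiment task.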


Reference:

https://blog.csdn.net/weixin_36488653/article/details/118485063

 

Source: https://www.cnblogs.com/luyizhou/p/15459729.html
