ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

使用注意力机制建模 - 标准化日期格式

2020-12-17 23:29:29  阅读:251  来源: 互联网

标签:vocab dense 标准化 建模 machine bidirectional 格式 lstm accuracy


文章目录


参考 基于深度学习的自然语言处理

本文使用attention机制的模型,将各种格式的日期转化成标准格式的日期

1. 概述

  • LSTM、GRU 减少了梯度消失的问题,但是对于复杂依赖结构的长句子,梯度消失仍然存在
  • 注意力机制能同时看见句子中的每个位置,并赋予每个位置不同的权重(注意力),且可以并行计算

2. 数据

  • 生成日期数据
from faker import Faker
from babel.dates import format_date
import random
fake = Faker()
fake.seed(123)
random.seed(321)

# 各种日期格式
FORMATS = ['short',
           'medium',
           'long',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'd MMM YYY',
           'd MMMM YYY',
           'dd MMM YYY',
           'd MMM, YYY',
           'd MMMM, YYY',
           'dd, MMM YYY',
           'd MM YY',
           'd MMMM YYY',
           'MMMM d YYY',
           'MMMM d, YYY',
           'dd.MM.YY']
  • 生成日期数据:随机格式(X),标准格式(Y)
def load_date():
    # 加载一些日期数据
    dt = fake.date_object() # 随机一个日期
    human_readable = format_date(dt, format=random.choice(FORMATS),
                                 locale='en_US')
    # 使用随机选取的格式,生成日期
    human_readable = human_readable.lower().replace(',','')
    machine_readable = dt.isoformat() # 标准格式
    return human_readable, machine_readable, dt

test_date = load_date()

输出:

  • 建立字典,以及映射关系(字符 :idx)
from tqdm import tqdm # 显示进度条
def load_dateset(num_of_data):
    human_vocab = set()
    machine_vocab = set()
    dataset = []
    Tx = 30 # 日期最大长度
    for i in tqdm(range(num_of_data)):
        h, m, _ = load_date()
        if h is not None:
            dataset.append((h, m))
            human_vocab.update(tuple(h))
            machine_vocab.update(tuple(m))
    human = dict(zip(sorted(human_vocab)+['<unk>', '<pad>'],
                     list(range(len(human_vocab)+2))))
    # x 字符:idx 的映射
    inv_machine = dict(enumerate(sorted(machine_vocab)))
    # idx : y 字符
    machine = {v : k for k, v in inv_machine.items()}
    # y 字符 : idx
    return dataset, human, machine, inv_machine

m = 10000 # 样本个数
dataset, human_vocab, machine_vocab, inv_machine_vocab = load_dateset(m)
  • 日期(char序列)转 ids 序列,并且 pad / 截断
import numpy as np
from keras.utils import to_categorical

def string_to_int(string, length, vocab):
    string = string.lower().replace(',','')
    if len(string) > length: # 长了,截断
        string = string[:length]
    rep = list(map(lambda x : vocab.get(x, '<unk>'), string))
    # 对string里每个char 使用 匿名函数 获取映射的id,没有的话,使用unk的id,map返回迭代器,转成list
    if len(string) < length:
        rep += [vocab['<pad>']]*(length-len(string))
        # 长度不够,加上 pad 的 id
    return rep # 返回 [ids,...]
  • 根据 ids 序列生成 one_hot 矩阵
def process_data(dataset, human_vocab, machine_vocab, Tx, Ty):
    X,Y = zip(*dataset)
    print("处理前 X:{}".format(X))
    print("处理前 Y:{}".format(Y))
    X = np.array([string_to_int(date, Tx, human_vocab) for date in X])
    Y = [string_to_int(date, Ty, machine_vocab) for date in Y]
    print("处理后 X的shape:{}".format(X.shape))
    print("处理后 Y: {}".format(Y))

    Xoh = np.array(list(map(lambda x : to_categorical(x, num_classes=len(human_vocab)), X)))
    Yoh = np.array(list(map(lambda x : to_categorical(x, num_classes=len(machine_vocab)), Y)))
    return X, np.array(Y), Xoh, Yoh
Tx = 30 # 输入长度
Ty = 10 # 输出长度
X, Y, Xoh, Yoh = process_data(dataset, human_vocab, machine_vocab, Tx, Ty)


检查生成的 one_hot 编码矩阵维度

print(X.shape)
print(Y.shape)
print(Xoh.shape)
print(Yoh.shape)

输出:

(10000, 30)
(10000, 10)
(10000, 30, 37)
(10000, 10, 11)

3. 模型

  • softmax 激活函数,求注意力权重
from keras import backend as K
def softmax(x, axis=1):
    ndim = K.ndim(x)
    if ndim == 2:
        return K.softmax(x)
    elif ndim > 2:
        e = K.exp(x - K.max(x, axis=axis, keepdims=True))
        s = K.sum(e, axis=axis, keepdims=True)
        return e/s
    else:
        raise ValueError('维度不对,不能是1维')
  • 模型组件
from keras.layers import RepeatVector, LSTM, Concatenate, \
    Dense, Activation, Dot, Input, Bidirectional

repeator = RepeatVector(Tx) # 重复 Tx 次
# 重复器
# Input shape:
#     2D tensor of shape `(num_samples, features)`.
#
# Output shape:
#     3D tensor of shape `(num_samples, n, features)`.
concator = Concatenate(axis=-1) # 拼接器
densor1 = Dense(10, activation='tanh') # FC
densor2 = Dense(1, activation='relu') # FC
activator = Activation(softmax, name='attention_weights') # 计算注意力权重
dotor = Dot(axes=1) # 加权
  • 模型
def one_step_attention(h, s_prev):
    s_prev = repeator(s_prev) # 将前一个输出状态重复 Tx 次
    concat = concator([h, s_prev]) # 与 全部句子状态 拼接
    e = densor1(concat) # 经过 FC
    energies = densor2(e) # 经过FC
    alphas = activator(energies) # 得到注意力权重
    context = dotor([alphas, h]) # 跟原句子状态做attention
    return context # 得到上下文向量,后序输入到解码器

# 解码器,是一个单向LSTM
n_h = 32
n_s = 64
post_activation_LSTM_cell = LSTM(n_s, return_state=True) # 单向LSTM
output_layer = Dense(len(machine_vocab), activation=softmax) # FC 输出预测值

from keras.models import Model
def model(Tx, Ty, n_h, n_s, human_vocab_size, machine_vocab_size):
    X = Input(shape=(Tx,human_vocab_size), name='input_first')
    s0 = Input(shape=(n_s,),name='s0')
    c0 = Input(shape=(n_s,),name='c0')
    s = s0
    c = c0
    outputs = []
    h = Bidirectional(LSTM(n_h, return_sequences=True))(X) # 编码器得到整个序列的状态
    for t in range(Ty): # 解码器 推理
        context = one_step_attention(h, s) # attention 得到上下文向量
        s, _, c = post_activation_LSTM_cell(context, initial_state=[s,c])
        out = output_layer(s) # FC 输出预测
        outputs.append(out)
    model = Model(inputs=[X,s0,c0], outputs=outputs)
    return model
    
model = model(Tx,Ty,n_h,n_s,len(human_vocab), len(machine_vocab))
model.summary()

from keras.utils import plot_model
plot_model(model, to_file='model.png',show_shapes=True,rankdir='TB')

输出:

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_first (InputLayer)        [(None, 30, 37)]     0                                            
__________________________________________________________________________________________________
s0 (InputLayer)                 [(None, 64)]         0                                            
__________________________________________________________________________________________________
bidirectional (Bidirectional)   (None, 30, 64)       17920       input_first[0][0]                
__________________________________________________________________________________________________
repeat_vector (RepeatVector)    (None, 30, 64)       0           s0[0][0]                         
                                                                 lstm[0][0]                       
                                                                 lstm[1][0]                       
                                                                 lstm[2][0]                       
                                                                 lstm[3][0]                       
                                                                 lstm[4][0]                       
                                                                 lstm[5][0]                       
                                                                 lstm[6][0]                       
                                                                 lstm[7][0]                       
                                                                 lstm[8][0]                       
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 30, 128)      0           bidirectional[0][0]              
                                                                 repeat_vector[0][0]              
                                                                 bidirectional[0][0]              
                                                                 repeat_vector[1][0]              
                                                                 bidirectional[0][0]              
                                                                 repeat_vector[2][0]              
                                                                 bidirectional[0][0]              
                                                                 repeat_vector[3][0]              
                                                                 bidirectional[0][0]              
                                                                 repeat_vector[4][0]              
                                                                 bidirectional[0][0]              
                                                                 repeat_vector[5][0]              
                                                                 bidirectional[0][0]              
                                                                 repeat_vector[6][0]              
                                                                 bidirectional[0][0]              
                                                                 repeat_vector[7][0]              
                                                                 bidirectional[0][0]              
                                                                 repeat_vector[8][0]              
                                                                 bidirectional[0][0]              
                                                                 repeat_vector[9][0]              
__________________________________________________________________________________________________
dense (Dense)                   (None, 30, 10)       1290        concatenate[0][0]                
                                                                 concatenate[1][0]                
                                                                 concatenate[2][0]                
                                                                 concatenate[3][0]                
                                                                 concatenate[4][0]                
                                                                 concatenate[5][0]                
                                                                 concatenate[6][0]                
                                                                 concatenate[7][0]                
                                                                 concatenate[8][0]                
                                                                 concatenate[9][0]                
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 30, 1)        11          dense[0][0]                      
                                                                 dense[1][0]                      
                                                                 dense[2][0]                      
                                                                 dense[3][0]                      
                                                                 dense[4][0]                      
                                                                 dense[5][0]                      
                                                                 dense[6][0]                      
                                                                 dense[7][0]                      
                                                                 dense[8][0]                      
                                                                 dense[9][0]                      
__________________________________________________________________________________________________
attention_weights (Activation)  (None, 30, 1)        0           dense_1[0][0]                    
                                                                 dense_1[1][0]                    
                                                                 dense_1[2][0]                    
                                                                 dense_1[3][0]                    
                                                                 dense_1[4][0]                    
                                                                 dense_1[5][0]                    
                                                                 dense_1[6][0]                    
                                                                 dense_1[7][0]                    
                                                                 dense_1[8][0]                    
                                                                 dense_1[9][0]                    
__________________________________________________________________________________________________
dot (Dot)                       (None, 1, 64)        0           attention_weights[0][0]          
                                                                 bidirectional[0][0]              
                                                                 attention_weights[1][0]          
                                                                 bidirectional[0][0]              
                                                                 attention_weights[2][0]          
                                                                 bidirectional[0][0]              
                                                                 attention_weights[3][0]          
                                                                 bidirectional[0][0]              
                                                                 attention_weights[4][0]          
                                                                 bidirectional[0][0]              
                                                                 attention_weights[5][0]          
                                                                 bidirectional[0][0]              
                                                                 attention_weights[6][0]          
                                                                 bidirectional[0][0]              
                                                                 attention_weights[7][0]          
                                                                 bidirectional[0][0]              
                                                                 attention_weights[8][0]          
                                                                 bidirectional[0][0]              
                                                                 attention_weights[9][0]          
                                                                 bidirectional[0][0]              
__________________________________________________________________________________________________
c0 (InputLayer)                 [(None, 64)]         0                                            
__________________________________________________________________________________________________
lstm (LSTM)                     [(None, 64), (None,  33024       dot[0][0]                        
                                                                 s0[0][0]                         
                                                                 c0[0][0]                         
                                                                 dot[1][0]                        
                                                                 lstm[0][0]                       
                                                                 lstm[0][2]                       
                                                                 dot[2][0]                        
                                                                 lstm[1][0]                       
                                                                 lstm[1][2]                       
                                                                 dot[3][0]                        
                                                                 lstm[2][0]                       
                                                                 lstm[2][2]                       
                                                                 dot[4][0]                        
                                                                 lstm[3][0]                       
                                                                 lstm[3][2]                       
                                                                 dot[5][0]                        
                                                                 lstm[4][0]                       
                                                                 lstm[4][2]                       
                                                                 dot[6][0]                        
                                                                 lstm[5][0]                       
                                                                 lstm[5][2]                       
                                                                 dot[7][0]                        
                                                                 lstm[6][0]                       
                                                                 lstm[6][2]                       
                                                                 dot[8][0]                        
                                                                 lstm[7][0]                       
                                                                 lstm[7][2]                       
                                                                 dot[9][0]                        
                                                                 lstm[8][0]                       
                                                                 lstm[8][2]                       
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 11)           715         lstm[0][0]                       
                                                                 lstm[1][0]                       
                                                                 lstm[2][0]                       
                                                                 lstm[3][0]                       
                                                                 lstm[4][0]                       
                                                                 lstm[5][0]                       
                                                                 lstm[6][0]                       
                                                                 lstm[7][0]                       
                                                                 lstm[8][0]                       
                                                                 lstm[9][0]                       
==================================================================================================
Total params: 52,960
Trainable params: 52,960
Non-trainable params: 0
________________________________________________________________________________________________

4. 训练

from keras.optimizers import Adam
# 优化器
opt = Adam(learning_rate=0.005, decay=0.01)
# 配置模型
model.compile(optimizer=opt, loss='categorical_crossentropy',
              metrics=['accuracy'])

# 初始化 解码器状态
s0 = np.zeros((m, n_s))
c0 = np.zeros((m, n_s))
outputs = list(Yoh.swapaxes(0, 1))
# Yoh shape 10000*10*11,调换0,1轴,为10*10000*11
# outputs list,长度 10, 每个里面是array 10000*11

history = model.fit([Xoh, s0, c0], outputs,
                    epochs=10, batch_size=128,
                    validation_split=0.1)
  • 绘制 loss 和 各位置的准确率
from matplotlib import pyplot as plt
import pandas as pd
his = pd.DataFrame(history.history)
print(his.columns)
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.plot(loss, label='train Loss')
plt.plot(val_loss, label='valid Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid()
plt.show()

# 列 具体的名字根据运行次数,会有变化
col_train_acc = (
 'dense_7_accuracy', 'dense_7_1_accuracy', 'dense_7_2_accuracy',
 'dense_7_3_accuracy', 'dense_7_4_accuracy', 'dense_7_5_accuracy',
 'dense_7_6_accuracy', 'dense_7_7_accuracy', 'dense_7_8_accuracy',
 'dense_7_9_accuracy')
col_test_acc = (
 'val_dense_7_accuracy', 'val_dense_7_1_accuracy',
 'val_dense_7_2_accuracy', 'val_dense_7_3_accuracy',
 'val_dense_7_4_accuracy', 'val_dense_7_5_accuracy',
 'val_dense_7_6_accuracy', 'val_dense_7_7_accuracy',
 'val_dense_7_8_accuracy', 'val_dense_7_9_accuracy')
train_acc = pd.DataFrame(history.history[c] for c in col_train_acc)
test_acc = pd.DataFrame(history.history[c] for c in col_test_acc)

train_acc.plot()
plt.title('Training Accuracy on pos')
plt.legend()
plt.grid()
plt.show()

test_acc.plot()
plt.title('Validation Accuracy on pos')
plt.legend()
plt.grid()
plt.show()

5. 测试

s0 = np.zeros((1, n_s))
c0 = np.zeros((1, n_s))
test_data,_,_,_ = load_dateset(10)
for x,y in test_data:
    print(x + " ==> " +y)
for x,_ in test_data:
    source = string_to_int(x, Tx, human_vocab)
    source = np.array(list(map(lambda a : to_categorical(a, num_classes=len(human_vocab)), source)))
    source = source[np.newaxis, :]
    pred = model.predict([source, s0, c0])
    pred = np.argmax(pred, axis=-1)
    output = [inv_machine_vocab[int(i)] for i in pred]
    print('source:',x)
    print('output:',''.join(output))

输出:

18 april 2014 ==> 2014-04-18
saturday august 22 1998 ==> 1998-08-22
october 22 1995 ==> 1995-10-22
thursday february 29 1996 ==> 1996-02-29
wednesday october 17 1979 ==> 1979-10-17
7 12 73 ==> 1973-12-07
9/30/01 ==> 2001-09-30
22 may 2001 ==> 2001-05-22
7 march 1979 ==> 1979-03-07
19 feb 2013 ==> 2013-02-19

预测10个,错误了4个,日期字符不完全正确

source: 18 april 2014
output: 2014-04-18
source: saturday august 22 1998
output: 1998-08-22
source: october 22 1995
output: 1995-12-22 # 错误 10 月
source: thursday february 29 1996
output: 1996-02-29
source: wednesday october 17 1979
output: 1979-10-17
source: 7 12 73
output: 1973-02-07 # 错误 12月
source: 9/30/01
output: 2001-05-00 # 错误 09-30
source: 22 may 2001
output: 2011-05-22 # 错误 2001
source: 7 march 1979
output: 1979-03-07
source: 19 feb 2013
output: 2013-02-19

标签:vocab,dense,标准化,建模,machine,bidirectional,格式,lstm,accuracy
来源: https://blog.csdn.net/qq_21201267/article/details/111300350

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有