ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

tensorflow 对Model检测点的操作、model.get_layer、从 checkpoint加载权重、set_weights、model层属性获取

2021-06-09 18:04:51  阅读:480  来源: 互联网

标签:bert layer tensor get variables model 检测点


文章目录

model层属性获取

在tensorflow中,要想获取层的输出的各种信息,可以先获取层对象,再通过层对象的属性获取层输出的其他特性.

获取model对应层的方法为:

get_layer(self, name=None, index=None):

函数功能:根据层的名称(这个名称具有唯一性)或者索引号检索model获取对应的层.

获取层输出的其他特性

  1. model.get_layer(index=0).output # 输出张量
  2. model.get_layer(index=0).output_shape #各自的形状
  3. model.get_layer(index=0).input # 输出张量
  4. model.get_layer(index=0).output_shape #各自的形状
  5. #该层有多个节点时(node_index为节点序号):
  6. layer.get_input_at(node_index)
  7. layer.get_output_at(node_index)
  8. layer.get_input_shape_at(node_index)
  9. layer.get_output_shape_at(node_index)
  10. model.get_layer(“word_embeddings”).set_weights(weights) #将权重加载到该层
  11. model.get_layer(“word_embeddings”).get_weights() #返回层的权重(numpy array)
  12. config = model.get_layer(“word_embeddings”).get_config() #保存该层的配置

检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

load_checkpoint

tf.train.load_checkpoint(ckpt_dir_or_file):

在ckpt_dir_or_file中找到的检查点返回’ CheckpointReader ’
如果’ ckpt_dir_or_file '解析为具有多个检查点的目录,则返回最新检查点的reader。

variables = tf.train.load_checkpoint(init_checkpoint)

从Checkpoint对象获取张量:

variables.get_tensor(“bert/embeddings/word_embeddings”)

bert 中load_checkpoint并且get_layer().set_weights操作

variables = tf.train.load_checkpoint(init_checkpoint)
# embedding weights
model._encoder_layer.get_layer("word_embeddings").set_weights([
        variables.get_tensor("bert/embeddings/word_embeddings")])
model._encoder_layer.get_layer("position_embeddings").set_weights([
        variables.get_tensor("bert/embeddings/position_embeddings")])
model._encoder_layer.get_layer("type_embeddings").set_weights([
        variables.get_tensor("bert/embeddings/token_type_embeddings")])

model._encoder_layer.get_layer("embeddings/layer_norm").set_weights([
        variables.get_tensor("bert/embeddings/LayerNorm/gamma"),
        variables.get_tensor("bert/embeddings/LayerNorm/beta")
])

model._encoder_layer.get_layer("embedding_projection").set_weights([
        variables.get_tensor("bert/encoder/embedding_hidden_mapping_in/kernel"),
        variables.get_tensor("bert/encoder/embedding_hidden_mapping_in/bias")
])
# multi attention weights

    for i in range(model._config['bert_config'].num_hidden_layers):
        model._encoder_layer.get_layer("transformer/layer_{}".format(i)).set_weights([
            tf.reshape(variables.get_tensor(
                "bert/encoder/layer_{}/attention/self/query/kernel".format(i)),
                [model.bert_config.hidden_size, model.bert_config.num_attention_heads, -1]),
            tf.reshape(
                variables.get_tensor("bert/encoder/layer_{}/attention/self/query/bias".format(i)),
                [model.bert_config.num_attention_heads, -1]),
            tf.reshape(variables.get_tensor(
                "bert/encoder/layer_{}/attention/self/key/kernel".format(i)),
                [model.bert_config.hidden_size, model.bert_config.num_attention_heads, -1]),
            tf.reshape(
                variables.get_tensor("bert/encoder/layer_{}/attention/self/key/bias".format(i)),
                [model.bert_config.num_attention_heads, -1]),
            tf.reshape(variables.get_tensor(
                "bert/encoder/layer_{}/attention/self/value/kernel".format(i)),
                [model.bert_config.hidden_size, model.bert_config.num_attention_heads, -1]),
            tf.reshape(
                variables.get_tensor("bert/encoder/layer_{}/attention/self/value/bias".format(i)),
                [model.bert_config.num_attention_heads, -1]),
            tf.reshape(variables.get_tensor(
                "bert/encoder/layer_{}/attention/output/dense/kernel".format(i)),
                [model.bert_config.num_attention_heads, -1, model.bert_config.hidden_size]),
            variables.get_tensor("bert/encoder/layer_{}/attention/output/dense/bias".format(i)),
            variables.get_tensor("bert/encoder/layer_{}/attention/output/LayerNorm/gamma".format(i)),
            variables.get_tensor("bert/encoder/layer_{}/attention/output/LayerNorm/beta".format(i)),
            variables.get_tensor("bert/encoder/layer_{}/intermediate/dense/kernel".format(i)),
            variables.get_tensor("bert/encoder/layer_{}/intermediate/dense/bias".format(i)),
            variables.get_tensor("bert/encoder/layer_{}/output/dense/kernel".format(i)),
            variables.get_tensor("bert/encoder/layer_{}/output/dense/bias".format(i)),
            variables.get_tensor("bert/encoder/layer_{}/output/LayerNorm/gamma".format(i)),
            variables.get_tensor("bert/encoder/layer_{}/output/LayerNorm/beta".format(i)),
        ])

    model._encoder_layer.get_layer("pooler_transform").set_weights([
        variables.get_tensor("bert/pooler/dense/kernel"),
        variables.get_tensor("bert/pooler/dense/bias"),
    ])

tf.train.list_variables(init_checkpoint)
#列出检查点中变量的检查点键和形状。
#bert 例子

init_vars = tf.train.list_variables(init_checkpoint)
    for name, shape in init_vars:
        if name.startswith("bert"):
            print(f"{name}, shape={shape}, *INIT FROM CKPT SUCCESS*")
import tensorflow as tf
import os
ckpt_directory = "/tmp/training_checkpoints/ckpt"
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(ckpt, ckpt_directory, max_to_keep=3)
train_and_checkpoint(model, manager)
tf.train.list_variables(manager.latest_checkpoint)

保存检测点

x_train,y_train,x_test,y_test=process_data()
model=mode(x_train,y_train,x_test,y_test)
checkpoint=tf.train.Checkpoint(A=model)    #保存model
checkpoint.save('./checkpoint/01.ckpt')   #在源文件夹建立一个checkpoint文件夹,保存的是文件目录加文件前缀

标签:bert,layer,tensor,get,variables,model,检测点
来源: https://blog.csdn.net/qq_43940950/article/details/117750809

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有