ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Tensorflow 获取model中的变量列表

2021-07-02 21:03:02  阅读:266  来源: 互联网

标签:get variables slim 列表 获取 tf Tensorflow model


1、动态获取 

(1)朴素获取法
       1) 朴素获取可训练变量:t_vars = tf.trainable_variables()
       2)朴素获取全部变量,包含声明training=False变量:all_vars = tf.global_variables()
(2)使用tensorflow.contrib.slim
       1) 获取常规变量(是slim里面与model变量对应的一个类型):regular_variables = slim.get_variables()
       2)直接获取:vars = slim.get_variables_to_restore()
       3)slim用于筛选方法
            a. 通过name筛选: variables = slim.get_variables_by_name("d_")
            b. 通过name后缀筛选:variables = slim.get_variables_by_suffix("_b")
            c. 通过namespace筛选:variables = slim.get_variables(scope="layer1")
            d. 通过include和exclude筛选
                d0. variables_to_restore = slim.get_variables_to_restore(include=["d_"])
                d1. variables_to_restore = slim.get_variables_to_restore(exclude=["_w"])
(3) 离线获取(从一个已保存好的模型中获取var_list)
    1) 离线文件: checkpoint、model.data-xxxx、model.index、model.meta
    2) 将离线文件载入当前环境,变成动态获取
        

#记住,要先清空现有的图
#不然的话import_meta_graph会把原model里面的数据追加到现有的model中
#一片混乱
tf.reset_default_graph()
 
with tf.Session(graph=tf.get_default_graph()) as sess:
    new_saver = tf.train.import_meta_graph('e:/mytrain/results/20190227_01/model/model.meta')
    new_saver.restore(sess, 'e:/mytrain/results/20190227_01/model/model')
    #加载进来之后还不是为所欲为
    var_list=tf.global_variables()
        


    3) 直接从离线文件获取
        

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
         
#文件夹地址改成自己的
model_dir="'e:\\20190227_01\\mytrain\\results\\20190227_01\\model"
         
ckpt = tf.train.get_checkpoint_state(model_dir)
reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)
         
#返回一个dict= {'name':[shape] }
#例如 'd_w2/Adam':[4, 4, 32, 64]
var_to_shape_map = reader.get_variable_to_shape_map()
         
#我们可以用遍历的方式,取出字典里所有的key
for key in var_to_shape_map:
    print(key)        #key是str类型的
    #再用key去找这个tensor的值
    a=reader.get_tensor(key)
    print(type(a))    #输出: <class 'numpy.ndarray'>

标签:get,variables,slim,列表,获取,tf,Tensorflow,model
来源: https://blog.csdn.net/NOT_GUY/article/details/118417634

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有