标签:python tensorflow
如何在tensorflow中保存和恢复变量?
我遇到了问题.我的代码:
import tensorflow as tf
v1 = tf.Variable(tf.zeros([2, 2], dtype=tf.float32, name='v1'))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(v1)
save_path = saver.save(sess, 'model.ckpt')
print "model saved in file:", save_path
v1 = v1 + 1
print sess.run(v1)
saver = tf.train.import_meta_graph('model.ckpt.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
print sess.run(v1)
结果:
[[ 0. 0.]
[ 0. 0.]]
[[ 1. 1.]
[ 1. 1.]]
[[ 1. 1.]
[ 1. 1.]]
我希望得到:
[[ 0. 0.]
[ 0. 0.]]
[[ 1. 1.]
[ 1. 1.]]
[[ 0. 0.]
[ 0. 0.]]
我犯了什么错误?
请帮我理解.
解决方法:
您的代码中有两个主要问题:
>行v1 = v1 1创建一个新的TensorFlow Tensor并将其绑定到Python变量v1,但不会更改使用名称“v1”创建的TensorFlow变量中的值.因此,当您稍后调用sess.run(v1)时,您正在评估将原始变量加1的新张量,而不是从张量中读取值.
相反,要将变量添加到变量,您应该使用以下内容:
increment_op = v1.assign_add(tf.ones([2, 2]))
sess.run(increment_op)
> tf.train.import_meta_graph()
调用重新创建原始图形,并在此过程中向图形中添加新节点,包括新的tf.train.Saver
.当您尚未构建图形(或者没有程序可用于此图形)时,它非常有用.由于您已经构建了图形,因此只需要使用saver.restore(sess,’model.ckpt’).
以下程序应该产生您预期的行为:
import tensorflow as tf
v1 = tf.Variable(tf.zeros([2, 2], dtype=tf.float32, name='v1'))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(v1)
save_path = saver.save(sess, './model.ckpt')
print "model saved in file:", save_path
# Create an op to increment v1, run it, and print the result.
increment_op = v1.assign_add(tf.ones([2, 2]))
sess.run(increment_op)
print sess.run(v1)
# Restore from the checkpoint saved above.
saver.restore(sess, './model.ckpt')
print sess.run(v1)
标签:python,tensorflow 来源: https://codeday.me/bug/20190608/1198344.html
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。