标签:acc RNN 训练 train batch 简易 test tf mnist
from tensorflow.contrib.layers import fully_connected from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf n_steps=28 n_inputs=28 n_nerons=150 n_outputs=10 learning_rate=0.001 x=tf.placeholder(tf.float32,[None,n_steps,n_inputs]) y=tf.placeholder(tf.int32,[None]) basic_cell=tf.contrib.rnn.BasicRNNCell(num_units=n_nerons) outputs,states=tf.nn.dynamic_rnn(basic_cell,x,dtype=tf.float32) logits=fully_connected(states,n_outputs,activation_fn=None) xentropy=tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=logits) loss=tf.reduce_mean(xentropy) optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate) training_op=optimizer.minimize(loss) correct=tf.nn.in_top_k(logits,y,1) accuracy=tf.reduce_mean(tf.cast(correct,tf.float32)) init=tf.global_variables_initializer() mnist=input_data.read_data_sets("/tmp/data/") x_test=mnist.test.images.reshape(-1,n_steps,n_inputs) y_test=mnist.test.labels n_epochs=100 batch_size=150 with tf.Session() as sess: init.run() for epoch in range(n_epochs): for iteration in range(mnist.train.num_examples//batch_size): x_batch,y_batch=mnist.train.next_batch(batch_size) x_batch=x_batch.reshape(-1,n_steps,n_inputs) sess.run(training_op,feed_dict={x:x_batch,y:y_batch}) acc_train=accuracy.eval(feed_dict={x:x_batch,y:y_batch}) acc_test = accuracy.eval(feed_dict={x: x_test, y: y_test}) print(epoch,'Train acc:',acc_train,'Test acc:',acc_test)
标签:acc,RNN,训练,train,batch,简易,test,tf,mnist 来源: https://blog.51cto.com/u_14540820/2759445
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。