ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

RNN简易训练

2021-05-07 18:04:03  阅读:171  来源: 互联网

标签:acc RNN 训练 train batch 简易 test tf mnist


from tensorflow.contrib.layers import fully_connected
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
n_steps=28
n_inputs=28
n_nerons=150
n_outputs=10

learning_rate=0.001

x=tf.placeholder(tf.float32,[None,n_steps,n_inputs])
y=tf.placeholder(tf.int32,[None])

basic_cell=tf.contrib.rnn.BasicRNNCell(num_units=n_nerons)
outputs,states=tf.nn.dynamic_rnn(basic_cell,x,dtype=tf.float32)

logits=fully_connected(states,n_outputs,activation_fn=None)
xentropy=tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=logits)

loss=tf.reduce_mean(xentropy)
optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op=optimizer.minimize(loss)
correct=tf.nn.in_top_k(logits,y,1)
accuracy=tf.reduce_mean(tf.cast(correct,tf.float32))
init=tf.global_variables_initializer()

mnist=input_data.read_data_sets("/tmp/data/")
x_test=mnist.test.images.reshape(-1,n_steps,n_inputs)
y_test=mnist.test.labels

n_epochs=100
batch_size=150

with tf.Session() as sess:
   init.run()
   for epoch in range(n_epochs):
      for iteration in range(mnist.train.num_examples//batch_size):
         x_batch,y_batch=mnist.train.next_batch(batch_size)
         x_batch=x_batch.reshape(-1,n_steps,n_inputs)
         sess.run(training_op,feed_dict={x:x_batch,y:y_batch})
      acc_train=accuracy.eval(feed_dict={x:x_batch,y:y_batch})
      acc_test = accuracy.eval(feed_dict={x: x_test, y: y_test})
      print(epoch,'Train acc:',acc_train,'Test acc:',acc_test)

标签:acc,RNN,训练,train,batch,简易,test,tf,mnist
来源: https://blog.51cto.com/u_14540820/2759445

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有