TensorFlow技术解析与实战 9.5 RNN
# -*- coding:utf-8 -*-
import sys
import importlib
importlib.reload(sys)
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# Read the MNIST data sets from the current directory with one-hot labels.
mnist = input_data.read_data_sets("./", one_hot=True)

# Images are reshaped to (N, 28, 28, 1): 28x28 single-channel pictures.
trX = mnist.train.images.reshape(-1, 28, 28, 1)
teX = mnist.test.images.reshape(-1, 28, 28, 1)

# Labels are kept exactly as returned by the data set objects.
trY = mnist.train.labels
teY = mnist.test.labels
# Training hyperparameters.
lr = 1e-3  # learning rate
training_iters = 100000  # training stops once step * batch_size reaches this
batch_size = 128  # examples per training batch

# Network dimensions.
n_inputs = 28  # size of the input vector at each time step
n_steps = 28  # sequence length (number of time steps)
n_hidden_units = 128  # number of units in the hidden layer
n_classes = 10  # number of output classes: the digits 0-9
# Placeholders for the fed data: x is (batch, n_steps, n_inputs) and
# y is (batch, n_classes); the batch dimension is left unspecified.
x = tf.placeholder("float", shape=[None, n_steps, n_inputs])
y = tf.placeholder("float", shape=[None, n_classes])

# Projection matrices, initialised from a normal distribution.
# 'in' maps an input step to the hidden size, 'out' maps hidden to classes.
w_in = tf.Variable(tf.random_normal([n_inputs, n_hidden_units]))
w_out = tf.Variable(tf.random_normal([n_hidden_units, n_classes]))
weights = {'in': w_in, 'out': w_out}

# Bias vectors for the same two projections, initialised to 0.1.
b_in = tf.Variable(tf.constant(0.1, shape=[n_hidden_units]))
b_out = tf.Variable(tf.constant(0.1, shape=[n_classes]))
biases = {'in': b_in, 'out': b_out}
#定义RNN模型
def RNN(X, weights, biases):
    """Build the LSTM classifier graph and return the class logits.

    The input is projected to the hidden size, run through a single
    BasicLSTMCell with ``tf.nn.dynamic_rnn``, and the final hidden state
    is projected to ``n_classes`` logits.

    Args:
        X: input tensor laid out as (batch, n_steps, n_inputs).
        weights: dict with an 'in' matrix of shape (n_inputs, n_hidden_units)
            and an 'out' matrix of shape (n_hidden_units, n_classes).
        biases: dict with 'in' and 'out' bias vectors matching ``weights``.

    Returns:
        Tensor of unnormalised class scores, one row per input sequence.
    """
    # Collapse batch and time so one matmul projects every step at once:
    # (batch * n_steps, n_inputs).
    X = tf.reshape(X, [-1, n_inputs])

    # Input projection, then restore the (batch, n_steps, n_hidden_units)
    # layout that dynamic_rnn expects when time_major=False.
    X_in = tf.matmul(X, weights['in']) + biases['in']
    X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])

    # NOTE: tf.contrib is only available in TensorFlow 1.x.
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)

    # No explicit initial_state is passed: given dtype, dynamic_rnn creates a
    # zero state sized from the actual batch dimension of X_in. The previous
    # zero_state(batch_size, ...) hard-wired the training batch size into the
    # graph, so feeding any other number of examples (e.g. the test set)
    # failed with a shape mismatch.
    _, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, dtype=tf.float32, time_major=False)

    # final_state is the (c_state, h_state) pair; index 1 is the hidden state
    # after the last time step.
    results = tf.matmul(final_state[1], weights['out']) + biases['out']
    return results
# Logits produced by the recurrent model for the fed images.
pred = RNN(x, weights, biases)

# Loss: softmax cross-entropy averaged over the batch, minimised with Adam.
per_example_loss = tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)
cost = tf.reduce_mean(per_example_loss)
optimizer = tf.train.AdamOptimizer(lr)
train_op = optimizer.minimize(cost)

# Accuracy: fraction of examples whose highest-scoring class matches the label.
predicted_class = tf.argmax(pred, 1)
true_class = tf.argmax(y, 1)
correct_pred = tf.equal(predicted_class, true_class)
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Train on mini-batches and print the accuracy on the current training
# batch every 20 steps.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step = 0
    while step * batch_size < training_iters:
        images, labels = mnist.train.next_batch(batch_size)
        # Reshape the fetched images into (batch, n_steps, n_inputs)
        # sequences; the same feed is used for training and for reporting.
        feed = {
            x: images.reshape([batch_size, n_steps, n_inputs]),
            y: labels,
        }
        sess.run([train_op], feed_dict=feed)
        if step % 20 == 0:
            print(sess.run(accuracy, feed_dict=feed))
        step += 1
0.2578125
0.671875
0.7578125
0.8203125
0.8984375
0.9296875
0.9140625
0.90625
0.875
0.9296875
0.9453125
0.9296875
0.9609375
0.921875
0.9296875
0.9609375
0.9453125
0.890625
0.9296875
0.9375
0.953125
0.9765625
0.9375
0.9375
0.9609375
0.9609375
0.96875
0.9609375
0.96875
0.9609375
0.96875
0.9453125
0.9609375
0.9921875
0.9765625
0.9765625
0.9765625
0.96875
0.953125
0.9765625
推荐阅读
- GIS跨界融合赋能多领域技术升级,江淮大地新应用成果喜人
- 深入浅出谈一下有关分布式消息技术(Kafka)
- Quartz 源码解析(四) —— QuartzScheduler和Listener事件监听
- Java内存泄漏分析系列之二:jstack生成的Thread Dump日志结构解析
- [源码解析] NVIDIA HugeCTR,GPU版本参数服务器---(3)
- Android系统启动之init.rc文件解析过程
- 小程序有哪些低成本获客手段——案例解析
- Spring源码解析_属性赋值
- 2月2日日课总结(基因技术)
- NAT(网络地址转换技术)