循环神经网络系列基于LSTM的MNIST手写体识别

眼前多少难甘事,自古男儿当自强。这篇文章主要讲述循环神经网络系列基于LSTM的MNIST手写体识别相关的知识,希望能为你提供帮助。


我们知道循环神经网络是用来处理包含序列化数据的相关问题,所有若要利用循环神经网络来解决某类问题,那么首先要做的就是将原始数据集序列化,然后处理成某个深度学习框架所接受的数据输入格式(比如Tensorflow).


1.数据预处理
我们知道MNIST数据集中的每张图片形状都是??[28*28]???的,那么如何将其序列化呢?想要知道怎么序列化,那还得从LSTM接受怎样的输入说起。由于??前文??我们说到,一个LSTM单元可以按时间维度进行展开处理,那么对于一个黑白图片来说该怎么展开呢?并且展开后必须得有前后的序列关系。 最直接的想法当然就是每个像素点当成一个部分,直接展开成784个LSTM单元。这种做法理论上当然没有什么问题,只是显得略微有点粗暴了;那我们就稍微缓和一点,按行或者按列来分割成28行(列),然后将这28个部分看成是序列。如下图所示:

循环神经网络系列基于LSTM的MNIST手写体识别

文章图片

所以,对于整个数据集来说:我们第一步要做的就是将其??reshape???成??[batchsize,high,width]???的这种形式;然后第二步就是将其按行分割成28个部分,变成??[timestep,batchsize,dim]???。以??batchsize = 4??为例,可以画出如下示意图:
循环神经网络系列基于LSTM的MNIST手写体识别

文章图片

温馨提示:4种颜色分别表示4张图片
因此,这部分对应代码就是:
x = tf.placeholder(dtype=tf.float32, shape=[None, 784], name=input-x)
y = tf.placeholder(dtype=tf.int32, shape=[None], name=input-y)
x_reshape = tf.reshape(x, shape=[-1, DIM, DIM], name=reshape-x)
x_tranpose = tf.transpose(x_reshape, perm=[1, 0, 2], name=transpose-x)


其中第4行就表示按行分割成28个部分,把所有的第i行都放在一起;
对于得到的这种形式的数据,我们在喂给展开后的LSTM时是长下面这个样的:
循环神经网络系列基于LSTM的MNIST手写体识别

文章图片

2.搭建网络
从??上文??的介绍可知,经过LSTM处理后,输出结果的格式是:??[timesteps,batchsize,outputsize]??。又由于我们做的仅仅是分类任务,所以我们接下来就取最后一个??timesteps??的输出作为整个LSTM的输出即可。同时,作为分类任务,我们需要得到每个类别的预测概率,因此还需要再LSTM的输出结果后加上一个??softmax??层,到此网络结构的搭建就完了。
循环神经网络系列基于LSTM的MNIST手写体识别

文章图片

这部分对应代码就是:
def lstm(inputs):
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=OUTPUT_SIZE)
h0 = cell.zero_state(batch_size=tf.shape(inputs)[1], dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs=inputs, initial_state=h0, time_major=True)
return outputs[-1]

y_ = lstm(x_tranpose)
with tf.name_scope(weighted-softmax):
weights = tf.Variable(tf.truncated_normal(shape=[OUTPUT_SIZE, OUTPUT_SIZE], stddev=0.1), dtype=tf.float32)
bias = tf.Variable(tf.constant(0, shape=[OUTPUT_SIZE], dtype=tf.float32))
logits = tf.nn.xw_plus_b(y_, weights, bias, name=softmax)


??源码戳此处??
更多内容欢迎扫码关注公众号月来客栈!
循环神经网络系列基于LSTM的MNIST手写体识别

文章图片



【循环神经网络系列基于LSTM的MNIST手写体识别】


    推荐阅读