CNN在mnist数据集上实现

这次我们使用CNN中最经典的Lenet网络在mnist数据集上进行训练和预测。

  • 卷积NN
    主要有两部分组成,一部分是对输入图片特征提取,一部分是全连接网络,主要组成操作包括卷积、池化、激活等。
  • 【CNN在mnist数据集上实现】Lenet网络模型
    Lenet是提出比较早,能有效解决手写数字图片识别的卷积模型,模型结构如下:

    CNN在mnist数据集上实现
    文章图片
    0.PNG
其中,padding=valid代表非全0填充,输出图片尺寸=(输入尺寸-卷积核尺寸+1)/步长;padding=same代表全0填充,输出尺寸=输入尺寸/步长;pooling不改变深度。
对Lenet进行调整使其使用于mnist数据集,结构如下:

CNN在mnist数据集上实现
文章图片
Lenet_on_mnist.PNG
实现还是分三模块:forward,backwa,test,主要改变是在forward:

CNN在mnist数据集上实现
文章图片
lenet1.png
定义获得权重、偏执,增加对卷积,池化的函数。
CNN在mnist数据集上实现
文章图片
lenet2.png 按上层结构前向传播,返回预测值。
backward和test跟上一篇中改动不大,主要是要注意输入的大小:

CNN在mnist数据集上实现
文章图片
leb1.png
输入占位大小改变
CNN在mnist数据集上实现
文章图片
leb2.png
喂入的barch_size大小改变
同理,在test文件中,测试数据的大小也相应改变。
新手学习,欢迎指教!

    推荐阅读