MNIST数据集用神经网络处理——TensorFlow实现
# Train a fully connected network with three sigmoid hidden layers on MNIST
# using the TensorFlow 1.x graph/session API, print the mean training cost
# per epoch, then report accuracy on the test split.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from ./mnist/ with one-hot labels (length N_CLASSES per label).
mnist = input_data.read_data_sets('./mnist/', one_hot=True)

INPUT_SIZE = 28        # images are INPUT_SIZE x INPUT_SIZE, fed flattened
HIDDEN1_SIZE = 256
HIDDEN2_SIZE = 256
HIDDEN3_SIZE = 256
LR = 0.002
N_CLASSES = 10
TRAINING_EPOCH = 25
BATCH_SIZE = 50

# Inputs: flattened images and one-hot labels, arbitrary batch size.
X = tf.placeholder(tf.float32, [None, INPUT_SIZE * INPUT_SIZE])
Y = tf.placeholder(tf.float32, [None, N_CLASSES])

# Hidden layer 1: input -> HIDDEN1_SIZE
W1 = tf.Variable(tf.random_normal([INPUT_SIZE * INPUT_SIZE, HIDDEN1_SIZE]))
b1 = tf.Variable(tf.random_normal([HIDDEN1_SIZE]))
L1 = tf.nn.sigmoid(tf.matmul(X, W1) + b1)

# Hidden layer 2: HIDDEN1_SIZE -> HIDDEN2_SIZE
W2 = tf.Variable(tf.random_normal([HIDDEN1_SIZE, HIDDEN2_SIZE]))
b2 = tf.Variable(tf.random_normal([HIDDEN2_SIZE]))
L2 = tf.nn.sigmoid(tf.matmul(L1, W2) + b2)

# Hidden layer 3: HIDDEN2_SIZE -> HIDDEN3_SIZE
# (the original used [HIDDEN1_SIZE, HIDDEN2_SIZE], which only worked because
# all hidden sizes happen to be equal)
W3 = tf.Variable(tf.random_normal([HIDDEN2_SIZE, HIDDEN3_SIZE]))
b3 = tf.Variable(tf.random_normal([HIDDEN3_SIZE]))
L3 = tf.nn.sigmoid(tf.matmul(L2, W3) + b3)

# Output layer: raw, unscaled logits. No activation here, because
# softmax_cross_entropy_with_logits applies softmax internally. Squashing
# through sigmoid first would confine the logits to (0, 1) and cripple training.
W = tf.Variable(tf.random_normal([HIDDEN3_SIZE, N_CLASSES]))
b = tf.Variable(tf.random_normal([N_CLASSES]))
hypothesis = tf.matmul(L3, W) + b

# Loss: batch-mean softmax cross-entropy.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=hypothesis, labels=Y))
# Adam optimizer.
optimizer = tf.train.AdamOptimizer(learning_rate=LR).minimize(cost)

# Accuracy: fraction of samples whose predicted class (argmax of the logits)
# matches the labelled class; the bool matches are cast to float before
# averaging.
correct_prediction = tf.equal(tf.argmax(hypothesis, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# The context manager guarantees the session is closed when done.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize all variables
    print('Learning started. It takes some time.')

    # Number of full mini-batches per epoch; constant, so computed once.
    total_batch = mnist.train.num_examples // BATCH_SIZE
    for epoch in range(TRAINING_EPOCH):
        avg_cost = 0
        for _batch in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
            c, _ = sess.run([cost, optimizer],
                            feed_dict={X: batch_xs, Y: batch_ys})
            avg_cost += c / total_batch
        print('Epoch:', '%04d' % (epoch + 1),
              'cost =', '{:.9f}'.format(avg_cost))
    print('Learning Finished!')

    print('Accuracy:', sess.run(accuracy,
                                feed_dict={X: mnist.test.images,
                                           Y: mnist.test.labels}))
【MNIST数据集用神经网络处理——TensorFlow实现】
文章图片
推荐阅读
- Docker应用：容器间通信与MariaDB数据库主从复制
- 使用协程爬取网页,计算网页数据大小
- Java|Java基础——数组
- Python数据分析(一)(Matplotlib使用)
- Jsr303做前端数据校验
- Spark|Spark 数据倾斜及其解决方案
- 数据库设计与优化
- 爬虫数据处理HTML转义字符
- 数据库总结语句
- MySQL数据库备份与恢复