千磨万击还坚劲,任尔东西南北风。这篇文章主要讲述20210608 TensorFlow 实现数字图片分类相关的知识,希望能为你提供帮助。
0-1 导包
import warnings
warnings.filterwarnings("ignore")import keras
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
1-1 构造数据
调用接口去下载数据
mnist = keras.datasets.mnist# 导入 mnist
(train_images, train_labels),(test_images,test_labels) = mnist.load_data()
60000条训练集,10000条测试集
print("train image shape:",train_images.shape,"train label",train_labels.shape)
print("test image shape:",test_images.shape,"test label",test_labels.shape)
-->
train image shape: (60000, 28, 28) train label (60000,)
test image shape: (10000, 28, 28) test label (10000,)
60000条训练集分为训练的图片和标签,图片是手写的数字图片,长和宽都是 28,标签就是数字集,测试集是 10000 个图片,测试标签也是 10000
1-1-1 显示数据和标签
print("image data:",train_images[0])
print("label data:",train_labels[0])
# 数据太多太大,这里便不展示了
# label data: 5
1-1-2
defplot_img(img):
plt.imshow(img.reshape(28,28),cmap="binary")# camp = "binary" 使用灰度图表示
plt.show()
# 查看图片
plot_img(train_images[3])
plot_img(train_images[0])
文章图片
1-1-3 现在进行一个 0 1 的二分类
data_0 = []
data_1 = []
# zip 把里面的子数据 拿出来了,img是(28,28);label 标签是一个值
for img,label in zip(train_images, train_labels):
if label == 0:
# 首先需要把图片拍平,现在的训练是一种类似 DNN 神经网络的模式
# 所以无法传入 28*28 的图片进行训练,reshape(28*28),将二阶变为一阶的 784
# 除以 255 的作用是,约束到 0 到 1 之间
img = img.reshape(28*28)/255
# 将图片和 label 放一起,变成 785 个数据,赋给 img
img = np.append(img,label)
data_0.append(img)
if label == 1:
img = img.reshape(28*28)/255
img = np.append(img,label)
data_1.append(img)
data_0 = np.array(data_0)
data_1 = np.array(data_1)
all_data = https://www.songbingjia.com/android/np.concatenate([data_0,data_1])
print(all_data.shape)
# --> (12665, 785)
# 前 784 是像素,最后的 1 是标签
np.random.shuffle(all_data)# np.random.shuffle 打乱顺序,避免前面全是0,后面全是 1
# 切分训练集和数据集,按理说 20% 测试集,80%训练集,这里只是简单的 前面的作为训练集,后200个作为测试集
train_data = https://www.songbingjia.com/android/all_data[:-200]
test_data = all_data[-200:]
2-1 数据分块
# 生成器
def gen_batch(data,batch_size):
np.random.shuffle(data)# 打乱,增加随机性
for i in range(len(data) // batch_size):
cursor = batch_size * i
batch_data = https://www.songbingjia.com/android/data[cursor : cursor + batch_size]
x = batch_data[:, 0:784]# 第一维度全取,第二维度取前 784 个像素,最后一个是标签
y = batch_data[:, 784]
yield x,y.reshape(-1,1)
remainder = len(data) % batch_size
if remainder != 0:
x, y = data[-remainder:, 0:784], data[-remainder:, 784]
yield x, y.reshape(-1,1)
for x_,y_ in gen_batch(train_data,128):
print(x_.shape)
print(y_.shape)
print(/'-------\')
break
# -->
# (128, 784)
# (128, 1)
# -------
3-1 超参数
learing_rate = 0.01
num_train_epochs = 50# 循环训练集总轮数
display_per_step = 100
batch_size = 128
4-1 计算图
graph = tf.Graph()
with graph.as_default():
x = tf.placeholder(shape=[None,784], dtype=tf.float32, name=\'x\')
y = tf.placeholder(shape=[None,1], dtype=tf.float32, name=\'y\')
w =tf.Variable(tf.ones(shape=[784,1]), dtype=tf.float32)
b =tf.Variable(0, dtype=tf.float32)
logits = tf.matmul(x, w) + b
y_pred = tf.sigmoid(logits)
# 定义loss
with tf.name_scope("loss"):
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits), name="calculate_loss")
# 定义优化器
with tf.name_scope("SGD"):
optimizer = tf.train.GradientDescentOptimizer(learing_rate)
train_step = optimizer.minimize(loss)
# 定义正确率
with tf.name_scope("calculation_accuracy"):
res_pred = tf.cast(tf.greater_equal(y_pred, 0.5), dtype=tf.float32)
acc = tf.reduce_mean(tf.cast(tf.equal(res_pred, y), dtype=tf.float32))
5-1 运行计算图
with tf.Session(graph=graph) as sess:
init = tf.global_variables_initializer()
sess.run(init)
step = 0
for epoch in range(num_train_epochs):
for x_, y_ in gen_batch(train_data, batch_size):
step += 1
_, l, acc_ = sess.run([train_step, loss, acc], feed_dict={x: x_, y: y_})
if step % display_per_step == 0:
print("step: {:>
4}, loss: {:.4}, acc: {:.4%}".format(step, l, acc_))
print(\'training over\')
x_test,y_test = next(gen_batch(test_data,200))# 取出全部测试集数据
# 查看测试集的 loss 和正确率
loss_test, acc_test = sess.run([loss, acc],feed_dict={x: x_test, y: y_test})
print("test loss is {:.4}, acc is {:.4%}".format(loss_test, acc_test))
res_weights = sess.run([w, b])
# res_weights 存在的目的是进行模型的预测,给出一张图片,判断能否正确预测
-->
step:100, loss: 43.14, acc: 60.1562%
step:200, loss: 34.86, acc: 50.7812%
step:300, loss: 19.21, acc: 52.3438%
……
step: 4700, loss: 0.09711, acc: 96.8750%
step: 4800, loss: 0.1117, acc: 97.6562%
step: 4900, loss: 0.0001993, acc: 100.0000%
training over
test loss is 0.04518, acc is 99.0000%
6-1 模型预测
def plot_img(img):
plt.imshow(img.reshape(28,28),cmap="binary")
plt.show()
test_img = test_data[7,:784]
plot_img(test_img)
文章图片
6-2
graph = tf.Graph()
with graph.as_default():
x = tf.placeholder(shape=[None,784], dtype=tf.float32, name=\'x\')
y = tf.placeholder(shape=[None,1], dtype=tf.float32, name=\'y\')
# w 和 b 是前面训练好的值
w =tf.Variable(res_weights[0], dtype=tf.float32)
b =tf.Variable(res_weights[1], dtype=tf.float32)
logits = tf.matmul(x, w) + b
y_pred = tf.sigmoid(logits)
# 定义loss
with tf.name_scope("loss"):
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits), name="calculate_loss")
# 定义优化器
with tf.name_scope("SGD"):
optimizer = tf.train.GradientDescentOptimizer(learing_rate)
train_step = optimizer.minimize(loss)
# 定义正确率
with tf.name_scope("calculation_accuracy"):
res_pred = tf.cast(tf.greater_equal(y_pred, 0.5), dtype=tf.float32)
acc = tf.reduce_mean(tf.cast(tf.equal(res_pred, y), dtype=tf.float32))
with tf.Session(graph=graph) as sess:
# 初始化权重,将前面训练好的权重,放到计算图里
init = tf.global_variables_initializer()
sess.run(init)
# test_data[7,:784].shape 是 (784,)
# 但是x = tf.placeholder(shape=[None,784], dtype=tf.float32, name=\'x\')
# 所以需要 reshape 成为 (1,-1),也就是第一维是 1,第二维就是 784/1
# 所以 当前这里图片 的 shape 就是 (1,784)
x_ = test_data[7,:784].reshape(1,-1)
# 从这地方可以看出,预测输出的时候,并没有使用标签,所以传参时只传 x 就可以
res_p = sess.run(res_pred, feed_dict={x: x_})
print(res_p)
【20210608 TensorFlow 实现数字图片分类】# --> [[0.]]
预测出的结果是 0,所以简单的网络结构确实可以对测试集的手写数字图片进行预测,这就是一个数字的图片分类
部分代码解释:
1. 1-1-3 中的 zip()
https://blog.51cto.com/u_15149862/2847102
2. 1-1-3 中列表的使用,比如 train_data = https://www.songbingjia.com/android/all_data[:-200]
https://blog.51cto.com/u_15149862/2704954
3. 2-1 中的生成器
https://blog.51cto.com/u_15149862/2844458
4. TensorFlow 基本用法
https://blog.51cto.com/u_15149862/2825353
推荐阅读
- python基础篇(二十一)——文件和异常(上)
- 保姆级利用Github搭建自己的个人博客,看完就会
- 如何在Java中轻松地从INI文件读取(解析)和写入INI文件
- 如何使用Python检查字符串是否是回文
- 如何在Java中生成具有自定义长度的随机字母数字字符串
- 如何使用WxPython创建HTML文件查看器
- 附录(基于Chrome DevTools网络面板的Web调试代理)
- 如何在Python中创建QR Code图像或SVG
- 为什么要学习Python编程语言()