python之tensorflow手把手实例讲解斑马线识别实现
一,斑马线的数据集
【python之tensorflow手把手实例讲解斑马线识别实现】数据集的构成:
test | train |
---|---|
zebra corssing:56 | zebra corssing:168 |
other:54 | other:164 |
import tensorflow as tffrom tensorflow.keras.preprocessing.image import ImageDataGeneratorimport numpy as npimport matplotlib.pyplot as pltimport keras
2.数据导入
train_dir=r'C:\Users\zx\深度学习\Zebra\train'test_dir=r'C:\Users\zx\深度学习\Zebra\test'train_datagen = ImageDataGenerator(rescale=1/255,rotation_range=10,#旋转horizontal_flip=True)train_generator = train_datagen.flow_from_directory(train_dir,(50,50),batch_size=1,class_mode='binary',shuffle=False)test_datagen = ImageDataGenerator(rescale=1/255)test_generator = test_datagen.flow_from_directory(test_dir,(50,50),batch_size=1,class_mode='binary',shuffle=False)
3.搭建模型
模型的建立仁者见智,可自己调节寻找更好的模型。
model = tf.keras.models.Sequential([# 第一层卷积,卷积核为,共16个,输入为150*150*1tf.keras.layers.Conv2D(16,(3,3),activation='relu',padding='same',input_shape=(50,50,3)),tf.keras.layers.MaxPooling2D((2,2)),# 第二层卷积,卷积核为3*3,共32个,tf.keras.layers.Conv2D(32,(3,3),activation='relu'),tf.keras.layers.MaxPooling2D((2,2)),# 第三层卷积,卷积核为3*3,共64个,tf.keras.layers.Conv2D(64,(3,3),activation='relu'),tf.keras.layers.MaxPooling2D((2,2)),# 第四层卷积,卷积核为3*3,共128个#tf.keras.layers.Conv2D(128,(3,3),activation='relu'),#tf.keras.layers.MaxPooling2D((2,2)),# 数据铺平tf.keras.layers.Flatten(),tf.keras.layers.Dense(32,activation='relu'),tf.keras.layers.Dense(16,activation='relu'),tf.keras.layers.Dense(2,activation='softmax')])print(model.summary())model.compile(optimize='adam',loss=tf.keras.losses.sparse_categorical_crossentropy,metrics=['acc'])
4,模型训练
history = model.fit(train_generator,epochs=20,verbose=1)model.save('./Zebra.h5')
模型训练过程:
文章图片
可以看到我们的模型在20轮的训练后acc从0.63上升到了0.96左右。
5,模型评估
model.evaluate(test_generator)
文章图片
#可视化plt.plot(history.history['acc'], label='accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.ylim([0.7, 1])plt.legend(loc='lower right')plt.title('acc')plt.show()
文章图片
6,模型预测
虽然我们的模型在训练过程中acc一度达到0.96,但测试集才是检验模型的唯一标准,在model.evaluate(test_generator)中的评分只有0.91左右,说明我们的模型已经能以很高的正确率来完成”斑马线“与“非斑马线”的二分类问题了,但我们还是要查看具体是哪些数据没有被模型正确得识别。
pred=model.predict(test_generator) #获取test集的输出filenames = test_generator.filenames#获取test数据的文件名
错误输出过程:
- 1,循环测试集长度,通过if语句先判断others还是zebra,再通过one-hot编码判断是否预测正确。
- 2,根据labels可知others': 0, 'zebra crossing': 1,以此来判断是否预测正确。
- 3,对 filenames[0]='others\\103.png',进行切片处理。
- 4,找到others的‘s'或 zebra crossing的‘g',使用find()在基础上+2为正切片的起点(样本编号前有'\'符号,故+2才能正确取出编号)。
- 5,如 :将filenames[i]的值赋给a,a[int(a.find('s')+2):]则表示为 'xx.png'。
- 6,将取出的样本编号与路径拼接,读取后作图。
- 7,break跳出循环。
for i in range(len(filenames)):if filenames[i][:6]=='others':if np.argmax(pred[i]) != 0:a=filenames[i]plt.figure()print('预测错误的图片:'+a[int(a.find('s')+2):])print('错误识别为"zebra crossing",正确类型是"others"')print('预测标签为:'+str(np.argmax(pred[i]))+',真实标签为:0')img = plt.imread('Zebra/test/others/'+a[int(a.find('s')+2):])plt.imshow(img)plt.title(a[int(a.find('s')+2):])plt.grid(False)breakif filenames[i][:6]=='zebra ':if np.argmax(pred[i]) != 1:b= filenames[i]plt.figure()print('预测错误的图片:'+b[int(b.find('g')+2):])print('错误识别为"others",正确类型是"zebra crossing"')print('预测标签为:'+str(np.argmax(pred[i]))+',真实标签为:1')img = plt.imread('Zebra/test/zebra crossing/'+b[int(b.find('g')+2):])plt.imshow(img)plt.title(b[int(b.find('g')+2):])plt.grid(False)break
文章图片
看到这个错误样本,我猜想可能是因为斑马线的部分只占了图像的一半左右,所以预测错误了。
这里是我做预测判断的思路,本可以不这么复杂的可以用test_generator.labels来获取数据的标签,再做判断。
test_generator.labels
文章图片
上面只输出了第一个错误的样本,所以接下来我们要看所有错误预测的样本
sum=0for i in range(len(filenames)):if filenames[i][:6]=='others':if np.argmax(pred[i]) != 0:a=filenames[i]print('预测错误的图片:'+a[int(a.find('s')+2):]+',错误识别为"zebra crossing",正确类型是"others"')sum=sum+1if filenames[i][:6]=='zebra ':if np.argmax(pred[i]) != 1:b= filenames[i]print('预测错误的图片:'+b[int(b.find('g')+2):]+',错误识别为"others",正确类型是"zebra crossing"')sum=sum+1print('错误率:'+str(sum/100)+'%')print('正确率:'+str((10000-sum)/100)+'%')
文章图片
三,分析 在构建模型时我尝试在最后一层只用一个神经元,用sigmoid激活函数,其他参数不变,在同样epochs=20的条件,也能很快收敛,达到很高的acc,测试集的评分也能在0.9左右,但是在最后输出全部错误样本的时候发现错误的样本远超过softmax,可能其中有些参数我没有根据sigmoid来调整,所以会有如此高的错误率,欢迎在评论区讨论。
到此这篇关于python之tensorflow手把手实例讲解斑马线识别实现的文章就介绍到这了,更多相关python tensorflow 斑马线识别内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
推荐阅读
- PMSJ寻平面设计师之现代(Hyundai)
- 太平之莲
- 闲杂“细雨”
- 七年之痒之后
- 深入理解Go之generate
- 由浅入深理解AOP
- 期刊|期刊 | 国内核心期刊之(北大核心)
- 生活随笔|好天气下的意外之喜
- 感恩之旅第75天
- python学习之|python学习之 实现QQ自动发送消息