Python|Python Opencv使用ann神经网络识别手写数字功能

opencv中也提供了一种类似于Keras的神经网络,即为ann,这种神经网络的使用方法与Keras的很接近。
关于mnist数据的解析,读者可以自己从网上下载相应压缩文件,用python自己编写解析代码,由于这里主要研究knn算法,为了图简单,直接使用Keras的mnist手写数字解析模块。
本次代码运行环境为:
python 3.6.8
opencv-python 4.4.0.46
opencv-contrib-python 4.4.0.46
下面的代码为使用ann进行模型的训练:

from keras.datasets import mnistfrom keras import utilsimport cv2import numpy as np#opencv中ANN定义神经网络层def create_ANN():ann=cv2.ml.ANN_MLP_create()#设置神经网络层的结构 输入层为784 隐藏层为80 输出层为10ann.setLayerSizes(np.array([784,64,10]))#设置网络参数为误差反向传播法ann.setTrainMethod(cv2.ml.ANN_MLP_BACKPROP)#设置激活函数为sigmoidann.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)#设置训练迭代条件 #结束条件为训练30次或者误差小于0.00001ann.setTermCriteria((cv2.TermCriteria_EPS|cv2.TermCriteria_COUNT,100,0.0001))return ann#计算测试数据上的识别率def evaluate_acc(ann,test_images,test_labels):#采用的sigmoid激活函数,需要对结果进行置信度处理 #对于大于0.99的可以确定为1 对于小于0.01的可以确信为0test_ret=ann.predict(test_images)#预测结果是一个元组test_pre=test_ret[1]#可以直接最大值的下标 (10000,)test_pre=test_pre.argmax(axis=1)true_sum=(test_pre==test_labels)return true_sum.mean()if __name__=='__main__':#直接使用Keras载入的训练数据(60000, 28, 28) (60000,)(train_images,train_labels),(test_images,test_labels)=mnist.load_data()#变换数据的形状并归一化train_images=train_images.reshape(train_images.shape[0],-1)#(60000, 784)train_images=train_images.astype('float32')/255test_images=test_images.reshape(test_images.shape[0],-1)test_images=test_images.astype('float32')/255#将标签变为one-hot形状 (60000, 10) float32train_labels=utils.to_categorical(train_labels)#测试数据标签不用变为one-hot (10000,)test_labels=test_labels.astype(np.int)#定义神经网络模型结构ann=create_ANN()#开始训练ann.train(train_images,cv2.ml.ROW_SAMPLE,train_labels)#在测试数据上测试准确率print(evaluate_acc(ann,test_images,test_labels))#保存模型ann.save('mnist_ann.xml')#加载模型myann=cv2.ml.ANN_MLP_load('mnist_ann.xml')

训练100次得到的准确率为0.9376,可以接着增加训练次数或者提高神经网络的层次结构深度来提高准确率。
使用ann神经网络的模型结构非常小,因为只是保存了权重参数。
Python|Python Opencv使用ann神经网络识别手写数字功能
文章图片

可以看到整个模型文件的大小才1M,而svm的大小为十多兆,knn的为几百兆,因此使用ann神经网络更加适合部署在客户端上。
接下来使用ann进行图片的测试识别:
import cv2import numpy as npif __name__=='__main__':#读取图片img=cv2.imread('shuzi.jpg',0)img_sw=img.copy()#将数据类型由uint8转为float32img=img.astype(np.float32)#图片形状由(28,28)转为(784,)img=img.reshape(-1,)#增加一个维度变为(1,784)img=img.reshape(1,-1)#图片数据归一化img=img/255#载入ann模型ann=cv2.ml.ANN_MLP_load('minist_ann.xml')#进行预测img_pre=ann.predict(img)#因为激活函数sigmoid,因此要进行置信度处理ret=img_pre[1]ret[ret>0.9]=1ret[ret<0.1]=0print(ret)cv2.imshow('test',img_sw)cv2.waitKey(0)

运行程序,结果如下,可见该模型正确识别了数字0.
Python|Python Opencv使用ann神经网络识别手写数字功能
文章图片

【Python|Python Opencv使用ann神经网络识别手写数字功能】到此这篇关于Python Opencv使用ann神经网络识别手写数字的文章就介绍到这了,更多相关python opencv识别手写数字内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

    推荐阅读