【源码】MATLAB深度学习实战DEMO
深度学习DEMO提供了三个实现目标识别的卷积神经网络CNN示例。三个例子分别为:
-
从零开始学习如何建立CNN;
-
使用已经训练过的模型(迁移学习);
-
用于特征提取的神经网络训练。
运行以上示例需要安装MATLAB自带的GPU和并行计算工具箱,DEMO 3还需要安装统计与机器学习工具箱。
下面简单介绍DEMO 1:从零开始学习如何建立CNN。
-
运行DownloadCIFAR10.m文件,下载DEMO运行所需要的数据。
-
执行以下代码将训练数据导入MATLAB;
%Feel free to choose which ever you like best!
categories= {‘Deer’,‘Dog’,‘Frog’,‘Cat’};
rootFolder= ‘cifar10Train’;
【【源码】MATLAB深度学习实战DEMO】imds= imageDatastore(fullfile(rootFolder, categories), …
'LabelSource', 'foldernames');
-
定义CNN的各层网络,这里可以根据自己的需要调整参数,下面的代码只是一个示例。
conv1= convolution2dLayer(5,varSize,‘Padding’,2,‘BiasLearnRateFactor’,2);
conv1.Weights= gpuArray(single(randn([5 5 3 varSize])*0.0001));
fc1= fullyConnectedLayer(64,‘BiasLearnRateFactor’,2);
fc1.Weights= gpuArray(single(randn([64 576])*0.1));
fc2= fullyConnectedLayer(4,‘BiasLearnRateFactor’,2);
fc2.Weights= gpuArray(single(randn([4 64])*0.1));
layers= [
imageInputLayer([varSize varSize 3]);
conv1;
maxPooling2dLayer(3,'Stride',2);
reluLayer();
convolution2dLayer(5,32,'Padding',2,'BiasLearnRateFactor',2);
reluLayer();
averagePooling2dLayer(3,'Stride',2);
convolution2dLayer(5,64,'Padding',2,'BiasLearnRateFactor',2);
reluLayer();
averagePooling2dLayer(3,'Stride',2);
fc1;
reluLayer();
fc2;
softmaxLayer()classificationLayer()];
-
设置CNN的训练选项,这些参数设置会严重影响CNN的工作性能,在设置之前应当准确理解这些参数的物理意义。
'InitialLearnRate', 0.001, ...'LearnRateSchedule', 'piecewise', ...'LearnRateDropFactor', 0.1, ...'LearnRateDropPeriod', 8, ...'L2Regularization', 0.004, ...'MaxEpochs', 10, ...'MiniBatchSize', 100, ...'Verbose', true);
-
开始训练CNN,训练时间长短与具体的硬件设备相关,一般会花费数分钟或以上。
Training on singleGPU.
Initializing imagenormalization.
|=========================================================================================|
| Epoch | Iteration | Time Elapsed | Mini-batch | Mini-batch | Base Learning|
| | | (seconds) | Loss | Accuracy | Rate |
|=========================================================================================|
| 1 | 1 | 0.25 | 1.3862 | 24.00% | 0.0010 |
| 1 | 50 | 1.86 | 1.2571 | 39.00% | 0.0010 |
| 1 | 100 | 3.35 | 1.2376 | 39.00% | 0.0010 |
| 1 | 150 | 4.90 | 1.1451 | 50.00% | 0.0010 |
| 1 | 200 | 6.39 | 1.0797 | 59.00% | 0.0010 |
| 2 | 250 | 8.03 | 0.8069 | 69.00% | 0.0010 |
| 2 | 300 | 9.64 | 1.1253 | 51.00% | 0.0010 |
| 2 | 350 | 11.20 | 0.9872 | 59.00% | 0.0010 |
| 2 | 400 | 12.75 | 0.9490 | 59.00% | 0.0010 |
| 3 | 450 | 14.31 | 0.7405 | 70.00% | 0.0010 |
| 3 | 500 | 15.77 | 0.9592 | 59.00% | 0.0010 |
| 3 | 550 | 17.28 | 0.9337 | 61.00% | 0.0010 |
| 3 | 600 | 18.77 | 0.8383 | 65.00% | 0.0010 |
| 4 | 650 | 20.30 | 0.6693 | 71.00% | 0.0010 |
| 4 | 700 | 21.80 | 0.8787 | 63.00% | 0.0010 |
| 4 | 750 | 23.27 | 0.8892 | 63.00% | 0.0010 |
| 4 | 800 | 24.76 | 0.7295 | 69.00% | 0.0010 |
| 5 | 850 | 26.28 | 0.6321 | 72.00% | 0.0010 |
| 5 | 900 | 27.77 | 0.8034 | 71.00% | 0.0010 |
| 5 | 950 | 29.26 | 0.8285 | 68.00% | 0.0010 |
| 5 | 1000 | 30.75 | 0.6893 | 69.00% | 0.0010 |
| 6 | 1050 | 32.27 | 0.5741 | 76.00% | 0.0010 |
| 6 | 1100 | 33.74 | 0.7280 | 73.00% | 0.0010 |
| 6 | 1150 | 35.20 | 0.8312 | 68.00% | 0.0010 |
| 6 | 1200 | 36.69 | 0.5876 | 77.00% | 0.0010 |
| 7 | 1250 | 38.25 | 0.5598 | 75.00% | 0.0010 |
| 7 | 1300 | 39.80 | 0.6704 | 77.00% | 0.0010 |
| 7 | 1350 | 41.37 | 0.7792 | 68.00% | 0.0010 |
| 7 | 1400 | 42.87 | 0.5495 | 78.00% | 0.0010 |
| 8 | 1450 | 44.40 | 0.5561 | 79.00% | 0.0010 |
| 8 | 1500 | 45.89 | 0.6032 | 81.00% | 0.0010 |
| 8 | 1550 | 47.39 | 0.7548 | 68.00% | 0.0010 |
| 8 | 1600 | 48.90 | 0.5371 | 78.00% | 0.0010 |
| 9 | 1650 | 50.49 | 0.5247 | 80.00% | 0.0001 |
| 9 | 1700 | 52.02 | 0.5989 | 79.00% | 0.0001 |
| 9 | 1750 | 53.60 | 0.6982 | 72.00% | 0.0001 |
| 9 | 1800 | 55.17 | 0.4448 | 78.00% | 0.0001 |
| 10 | 1850 | 56.71 | 0.4927 | 79.00% | 0.0001 |
| 10 | 1900 | 58.23 | 0.5630 | 80.00% | 0.0001 |
| 10 | 1950 | 59.71 | 0.6843 | 73.00% | 0.0001 |
| 10 | 2000 | 61.18 | 0.4486 | 79.00% | 0.0001 |
|=========================================================================================|
-
将测试验证数据导入MATLAB。
imds_test= imageDatastore(fullfile(rootFolder, categories), …
'LabelSource', 'foldernames');
-
测试结果输出,通过随机读取一幅图片进行分类测试,如果图片的标题为绿色,则预测结果正确;如果为红色,则预测结果错误。
ii= randi(4000);
im= imread(imds_test.Files{ii});
imshow(im);
iflabels(ii) ==imds_test.Labels(ii)
colorText = ‘g’;
else
colorText = 'r';
end
title(char(labels(ii)),‘Color’,colorText);
DEMO下载地址:
http://page2.dfpan.com/fs/9lc2j2821f29b1676d7/
更多精彩文章请关注微信号:
文章图片
推荐阅读
- 宽容谁
- 我要做大厨
- 增长黑客的海盗法则
- 画画吗()
- 2019-02-13——今天谈梦想()
- 远去的风筝
- 三十年后的广场舞大爷
- 叙述作文
- 20190302|20190302 复盘翻盘
- 学无止境,人生还很长