DL4J实战之二(鸢尾花分类)
欢迎访问我的GitHub
https://github.com/zq2599/blog_demos
内容:所有原创文章分类汇总及配套源码,涉及Java、Docker、Kubernetes、DevOPS等;
本篇概览
- 本文是《DL4J》实战的第二篇,前面做好了准备工作,接下来进入正式实战,本篇内容是经典的入门例子:鸢尾花分类
- 下图是一朵鸢尾花,我们可以测量到它的四个特征:花瓣(petal)的宽和高,花萼(sepal)的 宽和高:
文章图片
- 鸢尾花有三种:Setosa、Versicolor、Virginica
- 今天的实战是用前馈神经网络Feed-Forward Neural Network (FFNN)就行鸢尾花分类的模型训练和评估,在拿到150条鸢尾花的特征和分类结果后,我们先训练出模型,再评估模型的效果:
文章图片
源码下载
- 本篇实战中的完整源码可在GitHub下载到,地址和链接信息如下表所示(https://github.com/zq2599/blo...):
名称 | 链接 | 备注 |
---|---|---|
项目主页 | https://github.com/zq2599/blo... | 该项目在GitHub上的主页 |
git仓库地址(https) | https://github.com/zq2599/blo... | 该项目源码的仓库地址,https协议 |
git仓库地址(ssh) | git@github.com:zq2599/blog_demos.git | 该项目源码的仓库地址,ssh协议 |
- 这个git项目中有多个文件夹,《DL4J实战》系列的源码在dl4j-tutorials文件夹下,如下图红框所示:
文章图片
- dl4j-tutorials文件夹下有多个子工程,本次实战代码在dl4j-tutorials目录下,如下图红框:
文章图片
编码
- 在dl4j-tutorials工程下新建子工程classifier-iris,其pom.xml如下:
dlfj-tutorials
com.bolingcavalry
1.0-SNAPSHOT
4.0.0 classifier-iris8
8
com.bolingcavalry
commons
${project.version}
org.projectlombok
lombok
org.nd4j
${nd4j.backend}
ch.qos.logback
logback-classic
- 上述pom.xml有一处需要注意的地方,就是${nd4j.backend}参数的值,该值在决定了后端线性代数计算是用CPU还是GPU,本篇为了简化操作选择了CPU(因为个人的显卡不同,代码里无法统一),对应的配置就是nd4j-native;
- 源码全部在Iris.java文件中,并且代码中已添加详细注释,就不再赘述了:
package com.bolingcavalry.classifier;
import com.bolingcavalry.commons.utils.DownloaderUtility;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
/**
* @author will (zq2599@gmail.com)
* @version 1.0
* @description: 鸢尾花训练
* @date 2021/6/13 17:30
*/
@SuppressWarnings("DuplicatedCode")
@Slf4j
public class Iris {public static void main(String[] args) throwsException {//第一阶段:准备// 跳过的行数,因为可能是表头
int numLinesToSkip = 0;
// 分隔符
char delimiter = ',';
// CSV读取工具
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
// 下载并解压后,得到文件的位置
String dataPathLocal = DownloaderUtility.IRISDATA.Download();
log.info("鸢尾花数据已下载并解压至 : {}", dataPathLocal);
// 读取下载后的文件
recordReader.initialize(new FileSplit(new File(dataPathLocal,"iris.txt")));
// 每一行的内容大概是这样的:5.1,3.5,1.4,0.2,0
// 一共五个字段,从零开始算的话,标签在第四个字段
int labelIndex = 4;
// 鸢尾花一共分为三类
int numClasses = 3;
// 一共150个样本
int batchSize = 150;
//Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)// 加载到数据集迭代器中
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = https://www.it610.com/article/iterator.next();
// 洗牌(打乱顺序)
allData.shuffle();
// 设定比例,150个样本中,百分之六十五用于训练
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
//Use 65% of data for training// 训练用的数据集
DataSet trainingData = testAndTrain.getTrain();
// 验证用的数据集
DataSet testData = testAndTrain.getTest();
// 指定归一化器:独立地将每个特征值(和可选的标签值)归一化为0平均值和1的标准差。
DataNormalization normalizer = new NormalizerStandardize();
// 先拟合
normalizer.fit(trainingData);
// 对训练集做归一化
normalizer.transform(trainingData);
// 对测试集做归一化
normalizer.transform(testData);
// 每个鸢尾花有四个特征
final int numInputs = 4;
// 共有三种鸢尾花
int outputNum = 3;
// 随机数种子
long seed = 6;
//第二阶段:训练
log.info("开始配置...");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.activation(Activation.TANH)// 激活函数选用标准的tanh(双曲正切)
.weightInit(WeightInit.XAVIER)// 权重初始化选用XAVIER:均值 0, 方差为 2.0/(fanIn + fanOut)的高斯分布
.updater(new Sgd(0.1))// 更新器,设置SGD学习速率调度器
.l2(1e-4)// L2正则化配置
.list()// 配置多层网络
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)// 隐藏层
.build())
.layer(new DenseLayer.Builder().nIn(3).nOut(3)// 隐藏层
.build())
.layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)// 损失函数:负对数似然
.activation(Activation.SOFTMAX)// 输出层指定激活函数为:SOFTMAX
.nIn(3).nOut(outputNum).build())
.build();
// 模型配置
MultiLayerNetwork model = new MultiLayerNetwork(conf);
// 初始化
model.init();
// 每一百次迭代打印一次分数(损失函数的值)
model.setListeners(new ScoreIterationListener(100));
long startTime = System.currentTimeMillis();
log.info("开始训练");
// 训练
for(int i=0;
i<1000;
i++ ) {
model.fit(trainingData);
}
log.info("训练完成,耗时[{}]ms", System.currentTimeMillis()-startTime);
// 第三阶段:评估// 在测试集上评估模型
Evaluation eval = new Evaluation(numClasses);
INDArray output = model.output(testData.getFeatures());
eval.eval(testData.getLabels(), output);
log.info("评估结果如下\n" + eval.stats());
}
}
- 编码完成后,运行main方法,可见顺利完成训练并输出了评估结果,还有混淆矩阵用于辅助分析:
文章图片
- 至此,咱们的第一个实战就完成了,通过经典实例体验的DL4J训练和评估的常规步骤,对重要API也有了初步认识,接下来会继续实战,接触到更多的经典实例;
- Java系列
- Spring系列
- Docker系列
- kubernetes系列
- 数据库+中间件系列
- DevOps系列
微信搜索「程序员欣宸」,我是欣宸,期待与您一同畅游Java世界...
https://github.com/zq2599/blog_demos
推荐阅读
- 《机器学习实战》高清中文版PDF英文版PDF+源代码下载
- --木木--|--木木-- 第二课作业#翼丰会(每日一淘6+1实战裂变被动引流# 6+1模式)
- Java内存泄漏分析系列之二(jstack生成的Thread|Java内存泄漏分析系列之二:jstack生成的Thread Dump日志结构解析)
- 2020-07-29《吴军·阅读与写作50讲》24实战才能转化效能
- 暗能量时代之二
- Python实战计划学习笔记(9)为大规模爬取准备
- 韵达基于云原生的业务中台建设 | 实战派
- 【V课会】第3季-30天小学思维导图实战营
- 【思维导图实战派】刻意练习计划“遇见……”|【思维导图实战派】刻意练习计划“遇见……” 1/300 人教版数学五下第三单元《正方体和长方体的认识》
- OpenCV|OpenCV-Python实战(18)——深度学习简介与入门示例