自定义数据集 作为java深度学习框架,进行深度学习的时候,首先重要的是数据集,只有有了数据,才可以对自己的模型进行训练。
我这里采用的是人脸检测,这是数据集 CelebA 点击下载
DJL基础依赖
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.9.0</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>basicdataset</artifactId>
    <version>0.9.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>0.9.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.9.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-auto</artifactId>
    <version>1.7.1</version>
</dependency>
这是DJL的基础库依赖,有了这些,我们才可以进行深度学习。DJL底层调用的是由AWS开发维护的C++/C本地库,既然有AWS负责维护,直接拿来用,它不香吗?
数据集创建 这是官方的介绍
DJL中的数据集代表原始数据和加载过程。RandomAccessDataset实现了Dataset接口,并提供了全面的数据加载功能。RandomAccessDataset还是支持使用索引对数据进行随机访问的基本数据集。您可以通过扩展RandomAccessDataset轻松自定义自己的数据集。我这里介绍的主要是创建自定义数据集
官网地址
【深度学习|DJL——java深度学习框架学习笔记——自定义数据集】创建自定义数据集,需要继承RandomAccessDataset类,重写它的一些方法
废话不多说,直接上代码
package com.face.demo.utlis;
import ai.djl.Application;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.Resource;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslateException;
import ai.djl.util.JsonUtils;
import ai.djl.util.Progress;
import cn.hutool.core.io.file.FileReader;
import com.face.demo.pojo.FaceInfo;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.reflect.TypeToken;
import com.sun.imageio.plugins.common.ImageUtil;
import org.apache.commons.csv.CSVRecord;
import java.io.*;
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
public class FaceDataSetextends RandomAccessDataset {private static final String VERSION = "1.0";
private static final String ARTIFACT_ID = "banana";
private final Usage usage;
private final Image.Flag flag;
private final List imagePaths;
private final List labels;
private final Resource resource;
private boolean prepared;
public FaceDataSet(FaceDataSet.Builder builder) {
super(builder);
this.usage = builder.usage;
this.flag = builder.flag;
this.imagePaths = new ArrayList();
this.labels = new ArrayList();
MRL mrl = MRL.dataset(Application.CV.ANY, builder.groupId, builder.artifactId);
this.resource = new Resource(builder.repository, mrl, "1.0");
}public static FaceDataSet.Builder builder() {
return new FaceDataSet.Builder();
}@Override
public Record get(NDManager manager, long index) throws IOException {
int idx = Math.toIntExact(index);
NDList d = new NDList(new NDArray[]{ImageFactory.getInstance().fromFile((Path) this.imagePaths.get(idx)).toNDArray(manager, this.flag)});
NDArray label = manager.create((float[]) this.labels.get(idx));
NDList l = new NDList(new NDArray[]{label.reshape((new Shape(new long[]{1L})).addAll(label.getShape()))});
return new Record(d, l);
}@Override
protected long availableSize() {
return (long) this.imagePaths.size();
}@Override
public void prepare(Progress progress) throws IOException, TranslateException {if (!this.prepared) {Path usagePath = Paths.get("C:\\Users\\mzp\\Documents\\img_celeba.7z\\img_celeba\\img_celeba");
FileReader fileReader = new FileReader("C:\\Users\\mzp\\Documents\\Anno\\list_bbox_celeba.txt");
List> strings = fileReader.readLines();
strings.remove(0);
strings.remove(0);
strings.forEach((s) -> {
String[] s1 = s.split("\\s+");
FaceInfo faceInfo = new FaceInfo(s1);
float[] labelArray = new float[5];
labelArray[0] = 0.0f;
float[] normalized = Normalized(faceInfo);
labelArray[1] = (Float) normalized[0];
labelArray[2] = (Float) normalized[1];
labelArray[3] = (Float) normalized[2];
labelArray[4] = (Float) normalized[3];
this.imagePaths.add(usagePath.resolve(faceInfo.getImage_id()));
this.labels.add(labelArray);
});
this.prepared = true;
}}public static final class Builder extends BaseBuilder {
Repository repository;
String groupId;
String artifactId;
Usage usage;
Image.Flag flag;
Builder() {
this.repository = BasicDatasets.REPOSITORY;
this.groupId = "ai.djl.basicdataset";
this.artifactId = "face";
this.usage = Usage.TRAIN;
this.flag = Image.Flag.COLOR;
}public FaceDataSet.Builder self() {
return this;
}public FaceDataSet.Builder optUsage(Usage usage) {
this.usage = usage;
return this.self();
}public FaceDataSet.Builder optRepository(Repository repository) {
this.repository = repository;
return this.self();
}public FaceDataSet.Builder optGroupId(String groupId) {
this.groupId = groupId;
return this;
}public FaceDataSet.Builder optArtifactId(String artifactId) {
if (artifactId.contains(":")) {
String[] tokens = artifactId.split(":");
this.groupId = tokens[0];
this.artifactId = tokens[1];
} else {
this.artifactId = artifactId;
}return this;
}public FaceDataSet.Builder optFlag(Image.Flag flag) {
this.flag = flag;
return this.self();
}public FaceDataSet build() {
if (this.pipeline == null) {
this.pipeline = new Pipeline(new Transform[]{new ToTensor()});
}return new FaceDataSet(this);
}
}publicfloat[] Normalized(FaceInfo faceInfo) {
File file = new File(faceInfo.getImageURL());
try {
FileInputStream fileInputStream = new FileInputStream(file);
Image image = ImageFactory.getInstance().fromInputStream(fileInputStream);
float dw = 1.f / image.getWidth();
float dh = 1.f / image.getHeight();
float x_1 = Float.parseFloat(faceInfo.getX_1());
float y_1 = Float.parseFloat(faceInfo.getY_1());
float width = Float.parseFloat(faceInfo.getWidth());
float height = Float.parseFloat(faceInfo.getHeight());
float x = (x_1 + y_1) / 2.0f;
float y = (width + height) / 2.0f;
float w = y_1 - x_1;
float h = height - width;
x = x * dw;
w = w * dw;
y = y * dh;
h = h * dh;
float[] floats = new float[4];
floats[0] = x;
floats[1] = w;
floats[2] = y;
floats[3] = h;
return floats;
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
}
推荐阅读
- 金融中的AI和机器学习(在银行,保险,投资以及用户体验中的用例)
- 大数据|UCLA李婧翌(女性最不需要做的就是「怀疑自己」| 妇女节特辑)
- #|基于蒙特卡洛法的规模化电动汽车充电负荷预测(Python&Matlab实现)
- 经典论文解读|AlexNet经典论文解读
- pytorch深度学习|pytorch中张量的维度变换,torch.squeeze()、torch.unsqueeze()函数
- AIRX|数据科学家需要了解的15个Python库
- 脉冲神经网络|脉冲神经网络-基于IAF神经元的手写数字识别
- GAN|基于GAN的图像修复--论文笔记
- 神经网络|2020 年,苹果的 AI 还有创新吗()