Deep Learning | DJL (Java Deep Learning Framework) Study Notes: Custom Datasets

Custom datasets. When doing deep learning with a Java framework, the first and most important thing is the dataset: only once you have data can you train your own model.
Here I am working on face detection, using the CelebA dataset (click to download).
DJL base dependencies

<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.9.0</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>basicdataset</artifactId>
    <version>0.9.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>0.9.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.9.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-auto</artifactId>
    <version>1.7.1</version>
</dependency>

These are the base DJL library dependencies; with them in place we can start doing deep learning. Under the hood DJL calls into C/C++ native code developed and maintained by AWS, so since someone else maintains it for us, why not just use it?
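As a quick sanity check (my own addition, not from the original post), you can print which engine DJL resolved from the classpath; with the PyTorch dependencies above it should report the PyTorch engine once pytorch-native-auto has downloaded the native library:

import ai.djl.engine.Engine;

public class EngineCheck {
    public static void main(String[] args) {
        // Prints the engine DJL picked up from the classpath and its native version.
        Engine engine = Engine.getInstance();
        System.out.println(engine.getEngineName() + " " + engine.getVersion());
    }
}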
Dataset creation. Here is the official description:
A dataset in DJL represents both the raw data and the loading process. RandomAccessDataset implements the Dataset interface and provides comprehensive data-loading functionality. RandomAccessDataset is also the base dataset that supports random access to the data by index. You can easily customize your own dataset by extending RandomAccessDataset.
Official documentation link
What I cover here is mainly how to create a custom dataset.
To create a custom dataset, you extend RandomAccessDataset and override a few of its methods, mainly get(), availableSize(), and prepare(), together with a Builder based on its BaseBuilder.
Without further ado, here is the code:
package com.face.demo.utlis;

import ai.djl.Application;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.Resource;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;
import cn.hutool.core.io.file.FileReader;
import com.face.demo.pojo.FaceInfo;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public class FaceDataSet extends RandomAccessDataset {

    private static final String VERSION = "1.0";
    private static final String ARTIFACT_ID = "banana";

    private final Usage usage;
    private final Image.Flag flag;
    private final List<Path> imagePaths;
    private final List<float[]> labels;
    private final Resource resource;
    private boolean prepared;

    public FaceDataSet(Builder builder) {
        super(builder);
        this.usage = builder.usage;
        this.flag = builder.flag;
        this.imagePaths = new ArrayList<>();
        this.labels = new ArrayList<>();
        MRL mrl = MRL.dataset(Application.CV.ANY, builder.groupId, builder.artifactId);
        this.resource = new Resource(builder.repository, mrl, "1.0");
    }

    public static Builder builder() {
        return new Builder();
    }

    /** Returns one sample: the image as data, the normalized bounding box as label. */
    @Override
    public Record get(NDManager manager, long index) throws IOException {
        int idx = Math.toIntExact(index);
        NDList d = new NDList(
                ImageFactory.getInstance().fromFile(imagePaths.get(idx)).toNDArray(manager, flag));
        NDArray label = manager.create(labels.get(idx));
        NDList l = new NDList(label.reshape(new Shape(1).addAll(label.getShape())));
        return new Record(d, l);
    }

    @Override
    protected long availableSize() {
        return imagePaths.size();
    }

    /** Reads the CelebA bounding-box annotation file and fills the image/label lists. */
    @Override
    public void prepare(Progress progress) throws IOException, TranslateException {
        if (!prepared) {
            Path usagePath = Paths.get("C:\\Users\\mzp\\Documents\\img_celeba.7z\\img_celeba\\img_celeba");
            FileReader fileReader = new FileReader("C:\\Users\\mzp\\Documents\\Anno\\list_bbox_celeba.txt");
            List<String> strings = fileReader.readLines();
            // The first two lines are the image count and the column header.
            strings.remove(0);
            strings.remove(0);
            strings.forEach(s -> {
                String[] s1 = s.split("\\s+");
                FaceInfo faceInfo = new FaceInfo(s1);
                float[] labelArray = new float[5];
                labelArray[0] = 0.0f; // class id: face
                float[] normalized = Normalized(faceInfo);
                labelArray[1] = normalized[0];
                labelArray[2] = normalized[1];
                labelArray[3] = normalized[2];
                labelArray[4] = normalized[3];
                imagePaths.add(usagePath.resolve(faceInfo.getImage_id()));
                labels.add(labelArray);
            });
            prepared = true;
        }
    }

    public static final class Builder extends BaseBuilder<Builder> {

        Repository repository;
        String groupId;
        String artifactId;
        Usage usage;
        Image.Flag flag;

        Builder() {
            this.repository = BasicDatasets.REPOSITORY;
            this.groupId = "ai.djl.basicdataset";
            this.artifactId = "face";
            this.usage = Usage.TRAIN;
            this.flag = Image.Flag.COLOR;
        }

        @Override
        public Builder self() {
            return this;
        }

        public Builder optUsage(Usage usage) {
            this.usage = usage;
            return self();
        }

        public Builder optRepository(Repository repository) {
            this.repository = repository;
            return self();
        }

        public Builder optGroupId(String groupId) {
            this.groupId = groupId;
            return this;
        }

        public Builder optArtifactId(String artifactId) {
            if (artifactId.contains(":")) {
                String[] tokens = artifactId.split(":");
                this.groupId = tokens[0];
                this.artifactId = tokens[1];
            } else {
                this.artifactId = artifactId;
            }
            return this;
        }

        public Builder optFlag(Image.Flag flag) {
            this.flag = flag;
            return self();
        }

        public FaceDataSet build() {
            if (pipeline == null) {
                pipeline = new Pipeline(new ToTensor()); // default: convert the image to a CHW float tensor
            }
            return new FaceDataSet(this);
        }
    }

    /** Scales the bounding box by the image size so the label values fall in [0, 1]. */
    public float[] Normalized(FaceInfo faceInfo) {
        File file = new File(faceInfo.getImageURL());
        try (FileInputStream fileInputStream = new FileInputStream(file)) {
            Image image = ImageFactory.getInstance().fromInputStream(fileInputStream);
            float dw = 1.f / image.getWidth();
            float dh = 1.f / image.getHeight();
            float x_1 = Float.parseFloat(faceInfo.getX_1());
            float y_1 = Float.parseFloat(faceInfo.getY_1());
            float width = Float.parseFloat(faceInfo.getWidth());
            float height = Float.parseFloat(faceInfo.getHeight());
            float x = (x_1 + y_1) / 2.0f;
            float y = (width + height) / 2.0f;
            float w = y_1 - x_1;
            float h = height - width;
            x = x * dw;
            w = w * dw;
            y = y * dh;
            h = h * dh;
            return new float[] {x, w, y, h};
        } catch (IOException e) {
            e.printStackTrace();
        }
        return null;
    }
}
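A note on FaceInfo: the POJO from com.face.demo.pojo is referenced above but not shown in the post. Each data line of list_bbox_celeba.txt has the form image_id x_1 y_1 width height, so a minimal version matching the constructor and getters used in FaceDataSet could look roughly like this (my own sketch; getImageURL() in particular is an assumption about where the images live):

package com.face.demo.pojo;

// Hypothetical reconstruction of the FaceInfo POJO used by FaceDataSet.
// Each data line of list_bbox_celeba.txt is: image_id x_1 y_1 width height.
public class FaceInfo {
    private final String image_id;
    private final String x_1;
    private final String y_1;
    private final String width;
    private final String height;

    public FaceInfo(String[] fields) {
        this.image_id = fields[0];
        this.x_1 = fields[1];
        this.y_1 = fields[2];
        this.width = fields[3];
        this.height = fields[4];
    }

    public String getImage_id() { return image_id; }
    public String getX_1() { return x_1; }
    public String getY_1() { return y_1; }
    public String getWidth() { return width; }
    public String getHeight() { return height; }

    // Assumption: the image sits under the same folder used in prepare(); adjust to your layout.
    public String getImageURL() {
        return "C:\\Users\\mzp\\Documents\\img_celeba.7z\\img_celeba\\img_celeba\\" + image_id;
    }
}

To actually consume the dataset, you build it with a sampler, call prepare(), and iterate over batches. A minimal usage sketch (the batch size and ProgressBar are my own choices, not from the original post):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.util.ProgressBar;
import com.face.demo.utlis.FaceDataSet;

public class FaceDataSetDemo {
    public static void main(String[] args) throws Exception {
        // Build the dataset: batch size 32, shuffled (values chosen for the example).
        FaceDataSet dataset = FaceDataSet.builder()
                .optUsage(Dataset.Usage.TRAIN)
                .setSampling(32, true)
                .build();
        dataset.prepare(new ProgressBar()); // reads the CelebA annotation file

        try (NDManager manager = NDManager.newBaseManager()) {
            for (Batch batch : dataset.getData(manager)) {
                NDArray data = batch.getData().singletonOrThrow();    // image tensor
                NDArray labels = batch.getLabels().singletonOrThrow(); // [class, x, w, y, h]
                System.out.println(data.getShape() + " -> " + labels.getShape());
                batch.close();
            }
        }
    }
}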
