android设备模型,Pytroch模型部署到Android设备

本项目是一个简单的图像分类应用程序,演示了如何使用PyTorch Android API。此应用程序在静态图像上运行TorchScript序列化的TorchVision预训练的resnet18模型,该模型作为Android资产打包在应用程序内部。
1.模型准备
让我们从模型准备开始。如果您熟悉PyTorch,您可能应该已经知道如何训练和保存模型。如果您不这样做,我们将使用预先训练的图像分类模型(Resnet18),该模型包装在TorchVision中。要安装它,请运行以下命令:
pip install torchvision
要序列化模型,可以在HelloWorld应用的根文件夹中使用python 代码:
import torch
import torchvision
model = torchvision.models.resnet18(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("app/src/main/assets/model.pt")
如果一切正常,我们应该拥有我们的模型- model.pt在android应用程序的Assets文件夹中生成。它将被打包为android应用程序内部,asset并且可以在设备上使用。
2.从github克隆
git clone https://github.com/pytorch/android-demo-app.gitcd HelloWorldApp
如果已经安装了Android SDK和Android NDK,则可以使用以下命令将此应用程序安装到连接的android设备或模拟器上:
./gradlew installDebug
我们建议您在Android Studio 3.5.1+中打开此项目。目前,PyTorch Android和演示应用程序使用版本3.5.0的android gradle插件,只有Android Studio版本3.5.1和更高版本才支持。使用Android Studio,您将能够通过Android Studio UI安装Android NDK和Android SDK。
3. Gradle依赖
Pytorch android作为build.gradle中的gradle依赖项添加到项目中:
repositories {
jcenter()
}
dependencies {
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
org.pytorch:pytorch_androidPyTorch Android API的主要依赖项在哪里,包括所有4个Android abis(armeabi-v7a,arm64-v8a,x86,x86_64)的libtorch本机库。此外,在此文档中,您可以找到如何仅针对特定的android abis列表重建它。
org.pytorch:pytorch_android_torchvision-具有实用功能的附加库,用于转换android.media.Image和android.graphics.Bitmap张量。
4.从Android Asset读取图像
所有逻辑都发生在中org.pytorch.helloworld.MainActivity。作为第一步,我们阅读image.jpg了android.graphics.Bitmap使用标准Android API的信息。
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
5.加载TorchScript模块
Module module = Module.load(assetFilePath(this, "model.pt"));
org.pytorch.Module表示torch::jit::script::Module可以使用load指定序列化到文件模型的文件路径的方法加载。
6.准备输入
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
org.pytorch.torchvision.TensorImageUtils是org.pytorch:pytorch_android_torchvision图书馆的一部分。该TensorImageUtils#bitmapToFloat32Tensor方法在创建张量torchvision格式使用android.graphics.Bitmap作为源。
所有经过预训练的模型都希望输入图像以相同的方式归一化,即形状为(3 x H x W)的3通道RGB图像的迷你批,其中H和W至少应为224。加载到的范围内[0, 1],然后使用mean = [0.485, 0.456, 0.406]和进行归一化std = [0.229, 0.224, 0.225]
inputTensor的形状为1x3xHxW,其中H和W分别是位图的高度和宽度。
7.运行推理
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();
org.pytorch.Module.forward方法运行加载的模块的forward方法,并org.pytorch.Tensor使用shape 获得作为outputTensor的结果1x1000。
8.处理结果
使用以下org.pytorch.Tensor.getDataAsFloatArray()方法检索其内容:该方法返回浮点数的java数组,并为每个图像网络类分配分数。
之后,我们只找到具有最高分数的索引,然后从ImageNetClasses.IMAGENET_CLASSES包含所有ImageNet类的数组中检索预测的类名。
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
在以下各节中,您可以找到PyTorch Android API的详细说明,用于更大的演示应用程序的代码演练,API的实现细节,如何从源代码进行自定义和构建。
PYTORCH演示应用程序
我们还创建了另一个更复杂的PyTorch Android演示应用程序,该应用程序从同一github存储库中的摄像头输出和文本分类进行图像分类。
void setupCameraX() {
final PreviewConfig previewConfig = new PreviewConfig.Builder().build();
final Preview preview = new Preview(previewConfig);
preview.setOnPreviewOutputUpdateListener(output -> mTextureView.setSurfaceTexture(output.getSurfaceTexture()));
final ImageAnalysisConfig imageAnalysisConfig =
new ImageAnalysisConfig.Builder()
.setTargetResolution(new Size(224, 224))
.setCallbackHandler(mBackgroundHandler)
.setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
.build();
final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);
imageAnalysis.setAnalyzer(
(image, rotationDegrees) -> {
analyzeImage(image, rotationDegrees);
});
CameraX.bindToLifecycle(this, preview, imageAnalysis);
}
void analyzeImage(android.media.Image, int rotationDegrees)
该analyzeImage方法处理相机输出的位置android.media.Image。
从模型中获得预测分数后,它会找到分数最高的前K个类别,并在用户界面上显示。
语言处理示例
另一个示例是基于LSTM模型的自然语言处理,并在reddit注释数据集上进行了训练。逻辑发生在中TextClassificattionActivity。
结果类名称打包在TorchScript模型中,并在初始模块初始化后立即进行初始化。该模块具有一个get_classesreturn的方法,List[str]可以使用method进行调用Module.runMethod(methodName):
mModule = Module.load(moduleFileAbsoluteFilePath);
IValue getClassesOutput = mModule.runMethod("get_classes");
IValue可以将返回的值转换为IValueusing的java数组,IValue.toList()并使用以下方法处理为字符串数组IValue.toStr():
IValue[] classesListIValue = https://www.it610.com/article/getClassesOutput.toList();
String[] moduleClasses = new String[classesListIValue.length];
int i = 0;
for (IValue iv : classesListIValue) {
moduleClasses[i++] = iv.toStr();
}
输入的文本将转换为带有UTF-8编码的java字节数组。从该字节数组Tensor.fromBlobUnsigned创建张量dtype=uint8。
byte[] bytes = text.getBytes(Charset.forName("UTF-8"));
final long[] shape = new long[]{1, bytes.length};
final Tensor inputTensor = Tensor.fromBlobUnsigned(bytes, shape);
模型的运行推断与前面的示例相似:
Tensor outputTensor = mModule.forward(IValue.from(inputTensor)).toTensor()
【android设备模型,Pytroch模型部署到Android设备】之后,代码处理输出,找到得分最高的类。

    推荐阅读