在移动设备上部署机器学习模型是ML即将开始的新阶段。
目标检测模型,已经与语音识别、图像分类等模型一起应用于移动设备。这些模型通常运行在支持GPU的计算机上,部署在移动设备上时也有大量用例。
为了演示如何将ML模型,特别是对象检测模型引入Android的端到端示例,我们将使用Victor Dibia的手检测模型进行演示,该模型来自victordibia/handtracking repo。
https://github.com/victordibia/handtracking
该模型可以从图像中检测人手,并使用TensorFlow对象检测API制作。我们将使用Victor Dibia的repo中经过训练的模型,并将其转换为TensorFlow Lite(TFLite)格式,该格式可用于在Android(甚至iOS、Raspberry Pi)上运行该模型。
接下来,我们将转到Android应用程序,创建运行模型所需的所有必要类/方法,并在实时摄像机上显示其预测(边界框)。
让我们开始吧!
目录
- 将模型转换为TFLite
- 1.设置TF目标检测API
- 2.将检查点转换为图
- 3.将图转换为TFLite缓冲区
- 1.设置TF目标检测API
- 在Android中集成TFLite模型
- 【深度学习|在Android上部署TF目标检测模型】1.添加CameraX、Coroutines 和TF Lite的依赖项
- 2.初始化CameraX和ImageAnalysis.Analyzer
- 3.实现手部检测模型
- 4.绘制边界框
- 【深度学习|在Android上部署TF目标检测模型】1.添加CameraX、Coroutines 和TF Lite的依赖项
我们的第一步是将Victor Dibia的repo中提供的经过训练的模型检查点转换为TensorFlow Lite格式。TensorFlow Lite提供了一个在Android、iOS和微控制器设备上运行TensorFlow模型的高效网关。为了运行转换脚本,我们需要在机器上设置TensorFlow对象检测API。你也可以使用这个Colab笔记本来执行所有转换。
建议你使用Colab笔记本(尤其是Windows)。
https://github.com/shubham0204/Google_Colab_Notebooks/blob/main/Hand_Tracking_Model_TFLite_Conversion.ipynb
1.设置TF目标检测API TensorFlow对象检测API提供了许多预训练的对象检测模型,这些模型可以在自定义数据集上进行微调,并直接部署到移动、web或云中。我们只需要帮助我们将模型检查点转换为TF Lite缓冲区的转换脚本。
手部检测模型本身是使用TF OD API和TensorFlow 1.x制作的。所以,首先我们需要安装TensorFlow 1.x或TF 1.15.0(1.x系列中的最新版本),然后克隆包含TF OD API的tensorflow/models repo。
# Installing TF 1.15.0
!pip install tensorflow==1.15.0# Cloning the tensorflow/models repo
!git clone https://github.com/tensorflow/models# Installing the TF OD API
%%bash
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf1/setup.py .
python -m pip install .
此外,我们将克隆Victor Dibia仓库,以获得模型检查点
!git clone https://github.com/victordibia/handtracking
2.将检查点转换为图 现在在models/research/object_detection目录中,你将看到一个Python脚本export_tflite_ssd_graph.py,我们将使用它将模型检查点转换为与TFLite兼容的图。检查点可以在handtracking/model-checkpoint目录中找到。ssd代表‘Single Shot Detector’,这是手部检测模型的体系结构,而mobilenet代表mobilenet(v1或v2)的主干体系结构,这是一种专门用于移动设备的CNN体系结构。
文章图片
导出的TFLite图包含固定的输入和输出节点。我们可以在export_tflite_ssd_graph中找到这些节点(或张量)的名称和形状。使用该脚本,我们将把模型检查点转换为一个与TFLite兼容的图,给出三个参数,
- pipeline_config_path:指向的路径。包含所用SSD Lite模型配置的配置文件。
- trained_checkpoint_prefix:我们希望转换模型检查点的前缀。
- max_detections:要预测的边界框的数量。
!python models/research/object_detection/export_tflite_ssd_graph.py \
--pipeline_config_path handtracking/model-checkpoint/ssdlitemobilenetv2/data_ssdlite.config \
--trained_checkpoint_prefix handtracking/model-checkpoint/ssdlitemobilenetv2/out_model.ckpt-19040 \
--output_directory outputs \
--max_detections 10
脚本执行后,我们剩下两个文件,tflite_graph.pb和tflite_graph.pbtxt是与TFLite兼容的图形。
3.将图转换为TFLite缓冲区 现在,我们将使用第二个脚本(或者更准确地说,一个实用程序)将步骤2中生成的图转换为TFLite缓冲区(.tflite)。如TensorFlow 2.x排除了Session和Placeholder的使用,我们不能在这里将图转换为TFLite。这就是我们安装TensorFlow 1.x的原因之一。
我们将使用tflite_convert实用程序将图转换为tflite缓冲区。我们也可以使用tf.lite.TFLiteConverter API,但我们现在将继续使用命令行实用程序。
!tflite_convert \
--graph_def_file=/content/outputs/tflite_graph.pb \
--output_file=/content/outputs/model.tflite \
--output_format=TFLITE \
--input_arrays=normalized_input_image_tensor \
--input_shapes=1,300,300,3 \
--inference_type=FLOAT \
--output_arrays="TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3" \
--allow_custom_ops
执行完成后,你将看到一个outputs目录中的model.tflite文件。为了检查输入/输出形状,我们将使用tf.lite.Interpreter加载TFLite模型,调用.get_input_details()和.get_output_details()分别获取输入和输出详细信息。
提示:使用pprint可以获得漂亮的输出。
import tensorflow as tf
import pprintinterpreter = tf.lite.Interpreter( '/content/outputs/model.tflite' )
interpreter.allocate_tensors()pprint.pprint( interpreter.get_input_details())
pprint.pprint( interpreter.get_output_details() )
在Android中集成TFLite模型
一旦我们有了TFLite模型及其输入和输出形状的所有细节,我们就可以在Android应用程序中运行它了。在Android Studio中创建一个新项目,或者可以自由地克隆GitHub repo!
1.添加CameraX、Coroutines 和TF Lite的依赖项 当我们要在实时摄像头上检测手的时候,我们需要在Android应用程序中添加CameraX依赖项。类似地,为了运行TFLite模型,我们需要tensorflow lite依赖项和Kotlin Coroutines依赖项,它们帮助我们异步运行模型。在应用程序级构建中.gradle文件,我们将添加以下依赖项,
plugins {
...
}android {...aaptOptions {
noCompress "tflite"
}...}dependencies {
...// CameraX dependencies
implementation "androidx.camera:camera-camera2:1.0.1"
implementation "androidx.camera:camera-lifecycle:1.0.1"
implementation "androidx.camera:camera-view:1.0.0-alpha28"
implementation "androidx.camera:camera-extensions:1.0.0-alpha28"// TensorFlow Lite dependencies
implementation 'org.tensorflow:tensorflow-lite:2.4.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.4.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'// Kotlin Coroutines
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.4.1'
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-android:1.4.1'...
}
确保添加了aaptOptions{ noCompress "tflite" },这样系统就不会压缩模型,从而使应用程序的大小变小。现在,为了在我们的应用程序中放置TFLite模型,我们将在app/src/main下创建一个assets文件夹。将TFLite文件(.tflite)粘贴到此文件夹中。
文章图片
2.初始化CameraX和ImageAnalysis.Analyzer 我们将使用CameraX软件包中的PreviewView向用户显示实时摄像头提要。在它上面,我们将放置一个名为BoundingBoxOverlay的图,用于在摄影机提要上绘制边界框。我不会在这里讨论实现,但你可以从源代码中学到,
https://proandroiddev.com/realtime-selfie-segmentation-in-android-with-mlkit-38637c8502ba
当我们要预测实时帧数据的边界框时,我们还需要ImageAnalysis.Analyzer对象,该对象从实时摄影机提要返回每一帧。请参阅FrameAnalyzer中的这段代码。
// Image Analyser for performing hand detection on camera frames.
class FrameAnalyser(
private val handDetectionModel: HandDetectionModel ,
private val boundingBoxOverlay: BoundingBoxOverlay ) : ImageAnalysis.Analyzer {private var frameBitmap : Bitmap? = null
private var isFrameProcessing = falseoverride fun analyze(image: ImageProxy) {
// If a frame is being processed, drop the current frame.
if ( isFrameProcessing ) {
image.close()
return
}
isFrameProcessing = true// Get the `Bitmap` of the current frame ( with corrected rotation ).
frameBitmap = BitmapUtils.imageToBitmap( image.image!! , image.imageInfo.rotationDegrees )
image.close()// Configure frameHeight and frameWidth for output2overlay transformation matrix.
if ( !boundingBoxOverlay.areDimsInit ) {
Logger.logInfo( "Passing dims to overlay..." )
boundingBoxOverlay.frameHeight = frameBitmap!!.height
boundingBoxOverlay.frameWidth = frameBitmap!!.width
}CoroutineScope( Dispatchers.Main ).launch {
runModel( frameBitmap!! )
}
}private suspend fun runModel( inputImage : Bitmap ) = withContext( Dispatchers.Default ) {
...
}}
BitmapUtils包含一些有用的静态方法来操作Bitmap。isFrameProcessing是一个布尔变量,用于确定是否必须删除传入帧或将其传递给模型。正如你所观察到的,我们在一个CoroutineScope中运行模型,因此当模型产生推断时,你不会观察到任何滞后。
3.实现手部检测模型 接下来,我们将创建一个名为HandDetectionModel的类,该类将处理所有TFLite操作,并返回给定图像(作为位图)的预测。
// Helper class for Hand detection TFLite model
class HandDetectionModel( context: Context ) {// I/O details for the hand detection model.
// Refer to the comments of this script ->
// https://github.com/tensorflow/models/blob/master/research/object_detection/export_tflite_ssd_graph.py
// For quantization, use the tflite_convert utility as described in the conversion notebook ( README ).
private val modelInputImageDim = 300
private val isQuantized = false
private val maxDetections = 10
private val boundingBoxesTensorShape = intArrayOf( 1 , maxDetections , 4 ) // [ 1 , 10 , 4 ]
private val confidenceScoresTensorShape = intArrayOf( 1 , maxDetections ) // [ 1 , 10 ]
private val classesTensorShape = intArrayOf( 1 , maxDetections ) // [ 1 , 10 ]
private val numBoxesTensorShape = intArrayOf( 1 ) // [ 1 , ]
// Input tensor processor for quantized and non-quantized versions of the model.
private val inputImageProcessorQuantized = ImageProcessor.Builder()
.add( ResizeOp( modelInputImageDim , modelInputImageDim , ResizeOp.ResizeMethod.BILINEAR ) )
.add( CastOp( DataType.FLOAT32 ) )
.build()
private val inputImageProcessorNonQuantized = ImageProcessor.Builder()
.add( ResizeOp( modelInputImageDim , modelInputImageDim , ResizeOp.ResizeMethod.BILINEAR ) )
.add( NormalizeOp( 128.5f , 128.5f ) )
.build()// See app/src/main/assets for the TFLite model.
private val modelName = "model.tflite"
private val numThreads = 4
private var interpreter : Interpreter
// Confidence threshold for NMS
private val outputConfidenceThreshold = 0.9f...
我们将在上述片段中分别理解每个术语,
- modelImageInputDim是我们模型的输入图像的大小。我们的模型可以拍摄300*300的图像。
- maxDetections代表我们的模型所做预测的最大数量。它决定了boundingBoxesTensorShape、confidenceScoresTensorShape、classesTensorShape和numTensorShape的形状。
- outputConfidenceThreshold用于过滤模型做出的预测。这不是NMS,但我们只接受分数大于此阈值的框。
- inputImageProcessorQuantized和inputImageProcessorNonQuantized是TensorOperator的实例,它将给定图像的大小调整为modelImageInputDim*modelInputImageDim。对于量化模型,我们用平均值和标准偏差都等于127.5的标准化给定图像。
// Store the width and height of the input frames as they will be used for future transformations.
inputFrameWidth = inputImage.width
inputFrameHeight = inputImage.heightvar tensorImage = TensorImage.fromBitmap( inputImage )
tensorImage = if ( isQuantized ) {
inputImageProcessorQuantized.process( tensorImage )
}
else {
inputImageProcessorNonQuantized.process( tensorImage )
}val confidenceScores = TensorBuffer.createFixedSize( confidenceScoresTensorShape , DataType.FLOAT32 )
val boundingBoxes = TensorBuffer.createFixedSize( boundingBoxesTensorShape , DataType.FLOAT32 )
val classes = TensorBuffer.createFixedSize( classesTensorShape , DataType.FLOAT32 )
val numBoxes = TensorBuffer.createFixedSize( numBoxesTensorShape , DataType.FLOAT32 )
val outputs = mapOf(
0 to boundingBoxes.buffer ,
1 to classes.buffer ,
2 to confidenceScores.buffer ,
3 to numBoxes.buffer
)val t1 = System.currentTimeMillis()
interpreter.runForMultipleInputsOutputs( arrayOf(tensorImage.buffer), outputs )
Logger.logInfo( "Model inference time -> ${System.currentTimeMillis() - t1} ms." )return processOutputs( confidenceScores , boundingBoxes )
confidenceScores
, boundingBoxes
, classes
和numBoxes
是保存模型输出的四个张量。processOutputs方法将过滤边界框,并仅返回置信度得分大于阈值的框。private fun processOutputs( scores : TensorBuffer ,
boundingBoxes : TensorBuffer ) : List {
// Flattened version of array of shape [ 1 , maxDetections ] ( size = maxDetections )
val scoresFloatArray = scores.floatArray
// Flattened version of array of shape [ 1 , maxDetections , 4 ] ( size = maxDetections * 4 )
val boxesFloatArray = boundingBoxes.floatArray
val predictions = ArrayList()
for ( i in boxesFloatArray.indices step 4 ) {
// Store predictions which have a confidence > threshold
if ( scoresFloatArray[ i / 4 ] >= filterThreshold ) {
predictions.add(
Prediction(
getRect( boxesFloatArray.sliceArray( i..i+3 )) ,
scoresFloatArray[ i / 4 ]
)
)
}
}
return predictions.toList()
}// Transform the normalized bounding box coordinates relative to the input frames.
private fun getRect( coordinates : FloatArray ) : Rect {
return Rect(
max( (coordinates[ 1 ] * inputFrameWidth).toInt() , 1 ),
max( (coordinates[ 0 ] * inputFrameHeight).toInt() , 1 ),
min( (coordinates[ 3 ] * inputFrameWidth).toInt() , inputFrameWidth ),
min( (coordinates[ 2 ] * inputFrameHeight).toInt() , inputFrameHeight )
)
}
4.绘制边界框 一旦我们收到了边界框,我们会像使用OpenCV一样,将它们绘制。我们将创建一个新的类BoundingBoxOverlay,并将其添加到activity_main.xml。
class BoundingBoxOverlay(context : Context, attributeSet : AttributeSet)
: SurfaceView( context , attributeSet ) , SurfaceHolder.Callback {// Variables used to compute output2overlay transformation matrix
// These are assigned in FrameAnalyser.kt
var areDimsInit = false
var frameHeight = 0
var frameWidth = 0// This var is assigned in FrameAnalyser.kt
var handBoundingBoxes: List? = null// This var is assigned in MainActivity.kt
var isFrontCameraOn = falseprivate var output2OverlayTransform: Matrix = Matrix()
private val boxPaint = Paint().apply {
color = Color.YELLOW
style = Paint.Style.STROKE
strokeWidth = 16f
}
private val textPaint = Paint().apply {
strokeWidth = 2.0f
textSize = 32f
color = Color.YELLOW
}private val displayMetrics = DisplayMetrics()...override fun onDraw(canvas: Canvas?) {
if ( handBoundingBoxes == null ) {
return
}
if (!areDimsInit) {
...
}
else {
for ( prediction in handBoundingBoxes!! ) {
val rect = prediction.boundingBox.toRectF()
output2OverlayTransform.mapRect( rect )
canvas?.drawRoundRect( rect , 16f, 16f, boxPaint )
canvas?.drawText(
prediction.confidence.toString(),
rect.centerX(),
rect.centerY(),
textPaint
)
}
}
}}
就这些!我们刚刚在Android应用程序中实现了一个手检测器!你可以在查看所有代码后运行该应用程序。
文章图片
感谢阅读!
☆ END ☆
如果看到这里,说明你喜欢这篇文章,请转发、点赞。微信搜索「uncle_pn」,欢迎添加小编微信「 woshicver」,每日朋友圈更新一篇高质量博文。
↓扫描二维码添加小编↓
文章图片
推荐阅读
- 深度学习|Tensorflow2.x 用Lenet神经网络训练 fashion_mnist数据集 并预测。
- pytorch学习笔记|pytorch-resnet34残差网络理解
- 项目|python之逻辑回归项目实战——信用卡欺诈检测
- 计算机视觉|ResNet结构以及残差块详细分析
- python-opencv|opencv缩放
- 图像处理|5分钟学会,使用opencv进行基本的图像操作—读、写、显示、缩放、裁剪(python语言)
- PyTorch|PyTorch: hook机制
- 算法|2021谷歌年度AI技术总结 | Jeff Dean执笔万字展望人工智能的5大未来趋势!
- 动手学习深度学习|《动手学深度学习》Task04(机器翻译及相关技术+注意力机制与Seq2seq模型+Transformer)