深度学习|在Android上部署TF目标检测模型

在移动设备上部署机器学习模型是ML即将开始的新阶段。

目标检测模型,已经与语音识别、图像分类等模型一起应用于移动设备。这些模型通常运行在支持GPU的计算机上,部署在移动设备上时也有大量用例。
为了演示如何将ML模型,特别是对象检测模型引入Android的端到端示例,我们将使用Victor Dibia的手检测模型进行演示,该模型来自victordibia/handtracking repo。
https://github.com/victordibia/handtracking
该模型可以从图像中检测人手,并使用TensorFlow对象检测API制作。我们将使用Victor Dibia的repo中经过训练的模型,并将其转换为TensorFlow Lite(TFLite)格式,该格式可用于在Android(甚至iOS、Raspberry Pi)上运行该模型。
接下来,我们将转到Android应用程序,创建运行模型所需的所有必要类/方法,并在实时摄像机上显示其预测(边界框)。
让我们开始吧!
目录

  • 将模型转换为TFLite
    • 1.设置TF目标检测API
    • 2.将检查点转换为图
    • 3.将图转换为TFLite缓冲区
  • 在Android中集成TFLite模型
    • 【深度学习|在Android上部署TF目标检测模型】1.添加CameraX、Coroutines 和TF Lite的依赖项
    • 2.初始化CameraX和ImageAnalysis.Analyzer
    • 3.实现手部检测模型
    • 4.绘制边界框
将模型检查点转换为TFLite
我们的第一步是将Victor Dibia的repo中提供的经过训练的模型检查点转换为TensorFlow Lite格式。TensorFlow Lite提供了一个在Android、iOS和微控制器设备上运行TensorFlow模型的高效网关。为了运行转换脚本,我们需要在机器上设置TensorFlow对象检测API。你也可以使用这个Colab笔记本来执行所有转换。
建议你使用Colab笔记本(尤其是Windows)。
https://github.com/shubham0204/Google_Colab_Notebooks/blob/main/Hand_Tracking_Model_TFLite_Conversion.ipynb
1.设置TF目标检测API TensorFlow对象检测API提供了许多预训练的对象检测模型,这些模型可以在自定义数据集上进行微调,并直接部署到移动、web或云中。我们只需要帮助我们将模型检查点转换为TF Lite缓冲区的转换脚本。
手部检测模型本身是使用TF OD API和TensorFlow 1.x制作的。所以,首先我们需要安装TensorFlow 1.x或TF 1.15.0(1.x系列中的最新版本),然后克隆包含TF OD API的tensorflow/models repo。
# Installing TF 1.15.0 !pip install tensorflow==1.15.0# Cloning the tensorflow/models repo !git clone https://github.com/tensorflow/models# Installing the TF OD API %%bash cd models/research/ protoc object_detection/protos/*.proto --python_out=. cp object_detection/packages/tf1/setup.py . python -m pip install .

此外,我们将克隆Victor Dibia仓库,以获得模型检查点
!git clone https://github.com/victordibia/handtracking

2.将检查点转换为图 现在在models/research/object_detection目录中,你将看到一个Python脚本export_tflite_ssd_graph.py,我们将使用它将模型检查点转换为与TFLite兼容的图。检查点可以在handtracking/model-checkpoint目录中找到。ssd代表‘Single Shot Detector’,这是手部检测模型的体系结构,而mobilenet代表mobilenet(v1或v2)的主干体系结构,这是一种专门用于移动设备的CNN体系结构。
深度学习|在Android上部署TF目标检测模型
文章图片

导出的TFLite图包含固定的输入和输出节点。我们可以在export_tflite_ssd_graph中找到这些节点(或张量)的名称和形状。使用该脚本,我们将把模型检查点转换为一个与TFLite兼容的图,给出三个参数,
  1. pipeline_config_path:指向的路径。包含所用SSD Lite模型配置的配置文件。
  2. trained_checkpoint_prefix:我们希望转换模型检查点的前缀。
  3. max_detections:要预测的边界框的数量。
!python models/research/object_detection/export_tflite_ssd_graph.py \ --pipeline_config_path handtracking/model-checkpoint/ssdlitemobilenetv2/data_ssdlite.config \ --trained_checkpoint_prefix handtracking/model-checkpoint/ssdlitemobilenetv2/out_model.ckpt-19040 \ --output_directory outputs \ --max_detections 10

脚本执行后,我们剩下两个文件,tflite_graph.pb和tflite_graph.pbtxt是与TFLite兼容的图形。
3.将图转换为TFLite缓冲区 现在,我们将使用第二个脚本(或者更准确地说,一个实用程序)将步骤2中生成的图转换为TFLite缓冲区(.tflite)。如TensorFlow 2.x排除了Session和Placeholder的使用,我们不能在这里将图转换为TFLite。这就是我们安装TensorFlow 1.x的原因之一。
我们将使用tflite_convert实用程序将图转换为tflite缓冲区。我们也可以使用tf.lite.TFLiteConverter API,但我们现在将继续使用命令行实用程序。
!tflite_convert \ --graph_def_file=/content/outputs/tflite_graph.pb \ --output_file=/content/outputs/model.tflite \ --output_format=TFLITE \ --input_arrays=normalized_input_image_tensor \ --input_shapes=1,300,300,3 \ --inference_type=FLOAT \ --output_arrays="TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3" \ --allow_custom_ops

执行完成后,你将看到一个outputs目录中的model.tflite文件。为了检查输入/输出形状,我们将使用tf.lite.Interpreter加载TFLite模型,调用.get_input_details()和.get_output_details()分别获取输入和输出详细信息。
提示:使用pprint可以获得漂亮的输出。
import tensorflow as tf import pprintinterpreter = tf.lite.Interpreter( '/content/outputs/model.tflite' ) interpreter.allocate_tensors()pprint.pprint( interpreter.get_input_details()) pprint.pprint( interpreter.get_output_details() )

在Android中集成TFLite模型
一旦我们有了TFLite模型及其输入和输出形状的所有细节,我们就可以在Android应用程序中运行它了。在Android Studio中创建一个新项目,或者可以自由地克隆GitHub repo!
1.添加CameraX、Coroutines 和TF Lite的依赖项 当我们要在实时摄像头上检测手的时候,我们需要在Android应用程序中添加CameraX依赖项。类似地,为了运行TFLite模型,我们需要tensorflow lite依赖项和Kotlin Coroutines依赖项,它们帮助我们异步运行模型。在应用程序级构建中.gradle文件,我们将添加以下依赖项,
plugins { ... }android {...aaptOptions { noCompress "tflite" }...}dependencies { ...// CameraX dependencies implementation "androidx.camera:camera-camera2:1.0.1" implementation "androidx.camera:camera-lifecycle:1.0.1" implementation "androidx.camera:camera-view:1.0.0-alpha28" implementation "androidx.camera:camera-extensions:1.0.0-alpha28"// TensorFlow Lite dependencies implementation 'org.tensorflow:tensorflow-lite:2.4.0' implementation 'org.tensorflow:tensorflow-lite-gpu:2.4.0' implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'// Kotlin Coroutines implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.4.1' implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-android:1.4.1'... }

确保添加了aaptOptions{ noCompress "tflite" },这样系统就不会压缩模型,从而使应用程序的大小变小。现在,为了在我们的应用程序中放置TFLite模型,我们将在app/src/main下创建一个assets文件夹。将TFLite文件(.tflite)粘贴到此文件夹中。
深度学习|在Android上部署TF目标检测模型
文章图片

2.初始化CameraX和ImageAnalysis.Analyzer 我们将使用CameraX软件包中的PreviewView向用户显示实时摄像头提要。在它上面,我们将放置一个名为BoundingBoxOverlay的图,用于在摄影机提要上绘制边界框。我不会在这里讨论实现,但你可以从源代码中学到,
https://proandroiddev.com/realtime-selfie-segmentation-in-android-with-mlkit-38637c8502ba
当我们要预测实时帧数据的边界框时,我们还需要ImageAnalysis.Analyzer对象,该对象从实时摄影机提要返回每一帧。请参阅FrameAnalyzer中的这段代码。
// Image Analyser for performing hand detection on camera frames. class FrameAnalyser( private val handDetectionModel: HandDetectionModel , private val boundingBoxOverlay: BoundingBoxOverlay ) : ImageAnalysis.Analyzer {private var frameBitmap : Bitmap? = null private var isFrameProcessing = falseoverride fun analyze(image: ImageProxy) { // If a frame is being processed, drop the current frame. if ( isFrameProcessing ) { image.close() return } isFrameProcessing = true// Get the `Bitmap` of the current frame ( with corrected rotation ). frameBitmap = BitmapUtils.imageToBitmap( image.image!! , image.imageInfo.rotationDegrees ) image.close()// Configure frameHeight and frameWidth for output2overlay transformation matrix. if ( !boundingBoxOverlay.areDimsInit ) { Logger.logInfo( "Passing dims to overlay..." ) boundingBoxOverlay.frameHeight = frameBitmap!!.height boundingBoxOverlay.frameWidth = frameBitmap!!.width }CoroutineScope( Dispatchers.Main ).launch { runModel( frameBitmap!! ) } }private suspend fun runModel( inputImage : Bitmap ) = withContext( Dispatchers.Default ) { ... }}

BitmapUtils包含一些有用的静态方法来操作Bitmap。isFrameProcessing是一个布尔变量,用于确定是否必须删除传入帧或将其传递给模型。正如你所观察到的,我们在一个CoroutineScope中运行模型,因此当模型产生推断时,你不会观察到任何滞后。
3.实现手部检测模型 接下来,我们将创建一个名为HandDetectionModel的类,该类将处理所有TFLite操作,并返回给定图像(作为位图)的预测。
// Helper class for Hand detection TFLite model class HandDetectionModel( context: Context ) {// I/O details for the hand detection model. // Refer to the comments of this script -> // https://github.com/tensorflow/models/blob/master/research/object_detection/export_tflite_ssd_graph.py // For quantization, use the tflite_convert utility as described in the conversion notebook ( README ). private val modelInputImageDim = 300 private val isQuantized = false private val maxDetections = 10 private val boundingBoxesTensorShape = intArrayOf( 1 , maxDetections , 4 ) // [ 1 , 10 , 4 ] private val confidenceScoresTensorShape = intArrayOf( 1 , maxDetections ) // [ 1 , 10 ] private val classesTensorShape = intArrayOf( 1 , maxDetections ) // [ 1 , 10 ] private val numBoxesTensorShape = intArrayOf( 1 ) // [ 1 , ] // Input tensor processor for quantized and non-quantized versions of the model. private val inputImageProcessorQuantized = ImageProcessor.Builder() .add( ResizeOp( modelInputImageDim , modelInputImageDim , ResizeOp.ResizeMethod.BILINEAR ) ) .add( CastOp( DataType.FLOAT32 ) ) .build() private val inputImageProcessorNonQuantized = ImageProcessor.Builder() .add( ResizeOp( modelInputImageDim , modelInputImageDim , ResizeOp.ResizeMethod.BILINEAR ) ) .add( NormalizeOp( 128.5f , 128.5f ) ) .build()// See app/src/main/assets for the TFLite model. private val modelName = "model.tflite" private val numThreads = 4 private var interpreter : Interpreter // Confidence threshold for NMS private val outputConfidenceThreshold = 0.9f...

我们将在上述片段中分别理解每个术语,
  1. modelImageInputDim是我们模型的输入图像的大小。我们的模型可以拍摄300*300的图像。
  2. maxDetections代表我们的模型所做预测的最大数量。它决定了boundingBoxesTensorShape、confidenceScoresTensorShape、classesTensorShape和numTensorShape的形状。
  3. outputConfidenceThreshold用于过滤模型做出的预测。这不是NMS,但我们只接受分数大于此阈值的框。
  4. inputImageProcessorQuantized和inputImageProcessorNonQuantized是TensorOperator的实例,它将给定图像的大小调整为modelImageInputDim*modelInputImageDim。对于量化模型,我们用平均值和标准偏差都等于127.5的标准化给定图像。
现在,我们将实现一个run()方法,它将获取位图图像,并以List的形式输出边界框。Prediction是一个包含预测数据的类,如置信度得分和边界框坐标。
// Store the width and height of the input frames as they will be used for future transformations. inputFrameWidth = inputImage.width inputFrameHeight = inputImage.heightvar tensorImage = TensorImage.fromBitmap( inputImage ) tensorImage = if ( isQuantized ) { inputImageProcessorQuantized.process( tensorImage ) } else { inputImageProcessorNonQuantized.process( tensorImage ) }val confidenceScores = TensorBuffer.createFixedSize( confidenceScoresTensorShape , DataType.FLOAT32 ) val boundingBoxes = TensorBuffer.createFixedSize( boundingBoxesTensorShape , DataType.FLOAT32 ) val classes = TensorBuffer.createFixedSize( classesTensorShape , DataType.FLOAT32 ) val numBoxes = TensorBuffer.createFixedSize( numBoxesTensorShape , DataType.FLOAT32 ) val outputs = mapOf( 0 to boundingBoxes.buffer , 1 to classes.buffer , 2 to confidenceScores.buffer , 3 to numBoxes.buffer )val t1 = System.currentTimeMillis() interpreter.runForMultipleInputsOutputs( arrayOf(tensorImage.buffer), outputs ) Logger.logInfo( "Model inference time -> ${System.currentTimeMillis() - t1} ms." )return processOutputs( confidenceScores , boundingBoxes )

confidenceScores , boundingBoxes , classesnumBoxes是保存模型输出的四个张量。processOutputs方法将过滤边界框,并仅返回置信度得分大于阈值的框。
private fun processOutputs( scores : TensorBuffer , boundingBoxes : TensorBuffer ) : List
{ // Flattened version of array of shape [ 1 , maxDetections ] ( size = maxDetections ) val scoresFloatArray = scores.floatArray // Flattened version of array of shape [ 1 , maxDetections , 4 ] ( size = maxDetections * 4 ) val boxesFloatArray = boundingBoxes.floatArray val predictions = ArrayList
() for ( i in boxesFloatArray.indices step 4 ) { // Store predictions which have a confidence > threshold if ( scoresFloatArray[ i / 4 ] >= filterThreshold ) { predictions.add( Prediction( getRect( boxesFloatArray.sliceArray( i..i+3 )) , scoresFloatArray[ i / 4 ] ) ) } } return predictions.toList() }// Transform the normalized bounding box coordinates relative to the input frames. private fun getRect( coordinates : FloatArray ) : Rect { return Rect( max( (coordinates[ 1 ] * inputFrameWidth).toInt() , 1 ), max( (coordinates[ 0 ] * inputFrameHeight).toInt() , 1 ), min( (coordinates[ 3 ] * inputFrameWidth).toInt() , inputFrameWidth ), min( (coordinates[ 2 ] * inputFrameHeight).toInt() , inputFrameHeight ) ) }

4.绘制边界框 一旦我们收到了边界框,我们会像使用OpenCV一样,将它们绘制。我们将创建一个新的类BoundingBoxOverlay,并将其添加到activity_main.xml。
class BoundingBoxOverlay(context : Context, attributeSet : AttributeSet) : SurfaceView( context , attributeSet ) , SurfaceHolder.Callback {// Variables used to compute output2overlay transformation matrix // These are assigned in FrameAnalyser.kt var areDimsInit = false var frameHeight = 0 var frameWidth = 0// This var is assigned in FrameAnalyser.kt var handBoundingBoxes: List
? = null// This var is assigned in MainActivity.kt var isFrontCameraOn = falseprivate var output2OverlayTransform: Matrix = Matrix() private val boxPaint = Paint().apply { color = Color.YELLOW style = Paint.Style.STROKE strokeWidth = 16f } private val textPaint = Paint().apply { strokeWidth = 2.0f textSize = 32f color = Color.YELLOW }private val displayMetrics = DisplayMetrics()...override fun onDraw(canvas: Canvas?) { if ( handBoundingBoxes == null ) { return } if (!areDimsInit) { ... } else { for ( prediction in handBoundingBoxes!! ) { val rect = prediction.boundingBox.toRectF() output2OverlayTransform.mapRect( rect ) canvas?.drawRoundRect( rect , 16f, 16f, boxPaint ) canvas?.drawText( prediction.confidence.toString(), rect.centerX(), rect.centerY(), textPaint ) } } }}

就这些!我们刚刚在Android应用程序中实现了一个手检测器!你可以在查看所有代码后运行该应用程序。
深度学习|在Android上部署TF目标检测模型
文章图片

感谢阅读!
☆ END ☆
如果看到这里,说明你喜欢这篇文章,请转发、点赞。微信搜索「uncle_pn」,欢迎添加小编微信「 woshicver」,每日朋友圈更新一篇高质量博文。
↓扫描二维码添加小编↓
深度学习|在Android上部署TF目标检测模型
文章图片

    推荐阅读