TensorFlow Lite Model Maker --- Image Classification + Source Code
The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow neural-network model to particular input data when deploying this model for on-device ML applications.
Note: what we want here is a model in .tflite format, so that it can be deployed on mobile or embedded devices.
The table below lists the task types that TFLite Model Maker currently supports:
| Supported Tasks | Task Utility |
| --- | --- |
| Image Classification: tutorial, api | Classify images into predefined categories. |
| Object Detection: tutorial, api | Detect objects in real time. |
| Text Classification: tutorial, api | Classify text into predefined categories. |
| BERT Question Answer: tutorial, api | Find the answer in a certain context for a given question with BERT. |
| Audio Classification: tutorial, api | Classify audio into predefined categories. |
| Recommendation: demo, api | Recommend items based on the context information for on-device scenario. |
Note: if the model you want to train does not fit any of the task types above, you can train a regular TensorFlow model first and then convert it to TFLite yourself, as sketched below.
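A minimal sketch of that fallback path, assuming you already have a trained tf.keras model (the tiny model below is only a hypothetical stand-in):

```python
import tensorflow as tf

# Hypothetical toy model standing in for any trained tf.keras model.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(224, 224, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(5, activation='softmax'),
])

# Convert the Keras model directly to a .tflite flatbuffer and save it.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open('model.tflite', 'wb') as f:
    f.write(tflite_model)
```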
To use TensorFlow Lite Model Maker, we first need to install it:

```
pip install tflite-model-maker
```
At its core this is a classification task. We will swap in different backbone models and compare the final accuracy, as well as the TFLite model size, inference speed, memory usage, CPU usage, and so on.
The following snippet downloads the dataset:
```python
image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')
```
The dataset directory structure looks like this:
```
flower_photos
|__ daisy
|______ 100080576_f52e8ee070_n.jpg
|______ 14167534527_781ceb1b7a_n.jpg
|______ ...
|__ dandelion
|______ 10043234166_e6dd915111_n.jpg
|______ 1426682852_e62169221f_m.jpg
|______ ...
|__ roses
|______ 102501987_3cdb8e5394_n.jpg
|______ 14982802401_a3dfb22afb.jpg
|______ ...
|__ sunflowers
|______ 12471791574_bb1be83df4.jpg
|______ 15122112402_cafa41934f.jpg
|______ ...
|__ tulips
|______ 13976522214_ccec508fe7.jpg
|______ 14487943607_651e8062a1_m.jpg
|______ ...
```
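The subdirectory names become the class labels. A quick sketch (assuming the download above has completed) to count the images per class:

```python
import os

# Each subdirectory of flower_photos is one class; skip stray files such as LICENSE.txt.
for class_name in sorted(os.listdir(image_path)):
    class_dir = os.path.join(image_path, class_name)
    if os.path.isdir(class_dir):
        print(class_name, len(os.listdir(class_dir)))
```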
Load the dataset and split it into training, validation, and test sets (80% / 10% / 10%):
```python
data = DataLoader.from_folder(image_path)
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
```
```python
assert tf.__version__.startswith('2')
```

This asserts that the installed TensorFlow version starts with '2'.
Training results: train_acc = 0.9698, val_acc = 0.9375, test_acc = 0.9210. Overall this follows the expected generalization pattern (training > validation > test accuracy).
Putting it all together, the full training, evaluation, and export script:

```python
import os
import time

import numpy as np
import tensorflow as tf
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
import matplotlib.pyplot as plt

assert tf.__version__.startswith('2')

image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

data = DataLoader.from_folder(image_path)
# data = data.gen_dataset(batch_size=1)
train_data, rest_data = data.split(0.8)
# for batch in data.take(1):
#     print(batch)
#     break
validation_data, test_data = rest_data.split(0.5)

model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=model_spec.get('efficientnet_lite0'), epochs=20)
loss, accuracy = model.evaluate(test_data)

model.export(export_dir='./testTFlite',
             export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))

start = time.time()
print(model.evaluate_tflite('./testTFlite/model.tflite', test_data))
end = time.time()
print('elapsed time: ', end - start)
```
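Once the .tflite file exists it can also be run outside of Model Maker with the plain TFLite interpreter. A minimal sketch using a dummy input; in practice you would feed a preprocessed image, and the actual input shape/dtype should be read from the interpreter rather than assumed:

```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='./testTFlite/model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# Dummy input matching whatever shape/dtype the model reports.
dummy = np.zeros(input_details['shape'], dtype=input_details['dtype'])

interpreter.set_tensor(input_details['index'], dummy)
interpreter.invoke()
scores = interpreter.get_tensor(output_details['index'])
print(scores)  # one score per flower class
```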
Judging from the output log above, quantization costs almost no accuracy, and the quantized model is 4.0 MB (efficientnet_lite0).
The CPU monitor showed that inference runs on a single CPU core; with 367 images in test_data, the total evaluation time is 273.43 s (roughly 0.75 s per image).
```python
config = QuantizationConfig.for_float16()
model.export(export_dir='./testTFlite', tflite_filename='model_fp16.tflite',
             quantization_config=config,
             export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))
```
If the model is exported as fp16 instead, it is 6.8 MB (efficientnet_lite0) and evaluating the same test set takes 5.54 s, which is much faster.
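QuantizationConfig also covers other post-training schemes. A sketch of full-integer quantization, assuming for_int8 accepts a representative DataLoader (this matches the representative_data parameter seen later in the ImageClassifier source, but the exact API surface may differ across tflite-model-maker releases):

```python
# Assumption: QuantizationConfig.for_int8 takes a representative DataLoader.
int8_config = QuantizationConfig.for_int8(representative_data=test_data)
model.export(export_dir='./testTFlite', tflite_filename='model_int8.tflite',
             quantization_config=int8_config,
             export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))
```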
```python
model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=model_spec.get('mobilenet_v2'), epochs=20)
```
Switching the backbone to mobilenet_v2, the exported fp16 model is 4.6 MB and the evaluation takes 4.36 s.
```python
inception_v3_spec = image_classifier.ModelSpec(
    uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]
model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=inception_v3_spec, epochs=20)
```
Switching the backbone to inception_v3, the exported fp16 model is 43.8 MB and the evaluation takes 25.31 s.
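Pulling together the numbers reported above (all measured on the same 367-image test split), a quick per-image comparison:

```python
# Totals reported in this post for evaluating the 367-image test split.
results = {
    'efficientnet_lite0 (default quantized export)': {'size_mb': 4.0,  'total_s': 273.43},
    'efficientnet_lite0 (fp16)':                     {'size_mb': 6.8,  'total_s': 5.54},
    'mobilenet_v2 (fp16)':                           {'size_mb': 4.6,  'total_s': 4.36},
    'inception_v3 (fp16)':                           {'size_mb': 43.8, 'total_s': 25.31},
}
for name, r in results.items():
    print(f"{name}: {r['size_mb']} MB, {r['total_s'] / 367 * 1000:.0f} ms/image")
```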
Now for the relevant Model Maker source code. First, the common DataLoader used across tasks:

```python
class DataLoader(object):
  """This class provides generic utilities for loading customized domain data
  that will be used later in model retraining.

  For different ML problems or tasks, such as image classification, text
  classification etc., a subclass is provided to handle task-specific data
  loading requirements.
  """

  def __init__(self, dataset, size):
    """Init function for class `DataLoader`.

    In most cases, one should use helper functions like `from_folder` to create
    an instance of this class.

    Args:
      dataset: A tf.data.Dataset object that contains a potentially large set of
        elements, where each element is a pair of (input_data, target). The
        `input_data` means the raw input data, like an image, a text etc., while
        the `target` means some ground truth of the raw input data, such as the
        classification label of the image etc.
      size: The size of the dataset. tf.data.Dataset doesn't support a function
        to get the length directly since it's lazy-loaded and may be infinite.
    """
    self._dataset = dataset
    self._size = size

  def gen_dataset(self,
                  batch_size=1,
                  is_training=False,
                  shuffle=False,
                  input_pipeline_context=None,
                  preprocess=None,
                  drop_remainder=False):
    """Generate a shared and batched tf.data.Dataset for training/evaluation."""
```
Image dataloader:

```python
class ImageClassifierDataLoader(dataloader.ClassificationDataLoader):
  """DataLoader for image classifier."""

  @classmethod
  def from_folder(cls, filename, shuffle=True):
    """Image analysis for image classification load images with labels.

    Assume the image data of the same label are in the same subdirectory.

    Args:
      filename: Name of the file.
      shuffle: boolean, if shuffle, random shuffle data.
    Returns:
      ImageDataset containing images and labels and other related info.
    """

  @classmethod
  def from_tfds(cls, name):
    """Loads data from tensorflow_datasets."""
```
ImageNet preprocessing:

```python
class Preprocessor(object):
  """Preprocessing for image classification."""

  def __init__(self,
               input_shape,
               num_classes,
               mean_rgb,
               stddev_rgb,
               use_augmentation=False):
    self.input_shape = input_shape
    self.num_classes = num_classes
    self.mean_rgb = mean_rgb
    self.stddev_rgb = stddev_rgb
    self.use_augmentation = use_augmentation

  def __call__(self, image, label, is_training=True):
    if self.use_augmentation:
      return self._preprocess_with_augmentation(image, label, is_training)
    return self._preprocess_without_augmentation(image, label)

  def _preprocess_with_augmentation(self, image, label, is_training):
    """Image preprocessing method with data augmentation."""
    image_size = self.input_shape[0]
    if is_training:
      image = preprocess_for_train(image, image_size)
    else:
      image = preprocess_for_eval(image, image_size)

    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)

    label = tf.one_hot(label, depth=self.num_classes)
    return image, label

  # TODO(yuqili): Changes to preprocess to support batch input.
  def _preprocess_without_augmentation(self, image, label):
    """Image preprocessing method without data augmentation."""
    image = tf.cast(image, tf.float32)
    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)

    image = tf.compat.v1.image.resize(image, self.input_shape)
    label = tf.one_hot(label, depth=self.num_classes)
    return image, label
```
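In other words, each channel is normalized with the spec's RGB statistics, the image is resized to the spec's input shape, and the label is one-hot encoded. A standalone sketch of the same math, using made-up mean/stddev values purely for illustration (the real values come from model_spec):

```python
import tensorflow as tf

# Hypothetical statistics; the real ones live on the model spec
# (model_spec.mean_rgb / model_spec.stddev_rgb / input_image_shape).
mean_rgb = [127.0, 127.0, 127.0]
stddev_rgb = [128.0, 128.0, 128.0]
input_shape = [224, 224]

image = tf.random.uniform([300, 400, 3], maxval=255.0)  # stand-in for a decoded photo
image = (image - tf.constant(mean_rgb, shape=[1, 1, 3])) / tf.constant(stddev_rgb, shape=[1, 1, 3])
image = tf.image.resize(image, input_shape)
label = tf.one_hot(2, depth=5)  # e.g. class index 2 out of the 5 flower classes
print(image.shape, label.numpy())
```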
The ImageClassifier class that ties these pieces together:

```python
class ImageClassifier(classification_model.ClassificationModel):
  """ImageClassifier class for inference and exporting to tflite."""

  def __init__(self,
               model_spec,
               index_to_label,
               shuffle=True,
               hparams=hub_lib.get_default_hparams(),
               use_augmentation=False,
               representative_data=None):
    """Init function for ImageClassifier class.

    Args:
      model_spec: Specification for the model.
      index_to_label: A list that maps from index to label class name.
      shuffle: Whether the data should be shuffled.
      hparams: A namedtuple of hyperparameters. This function expects
        .dropout_rate: The fraction of the input units to drop, used in dropout
          layer.
        .do_fine_tuning: If true, the Hub module is trained together with the
          classification layer on top.
      use_augmentation: Use data augmentation for preprocessing.
      representative_data: Representative dataset for full integer
        quantization. Used when converting the keras model to the TFLite model
        with full integer quantization.
    """
    super(ImageClassifier, self).__init__(model_spec, index_to_label, shuffle,
                                          hparams.do_fine_tuning)
    num_classes = len(index_to_label)
    self._hparams = hparams
    self.preprocess = image_preprocessing.Preprocessor(
        self.model_spec.input_image_shape,
        num_classes,
        self.model_spec.mean_rgb,
        self.model_spec.stddev_rgb,
        use_augmentation=use_augmentation)
    self.history = None  # Training history that returns from `keras_model.fit`.
    self.representative_data = representative_data

  def _get_tflite_input_tensors(self, input_tensors):
    """Gets the input tensors for the TFLite model."""
    return input_tensors

  def create_model(self, hparams=None, with_loss_and_metrics=False):
    """Creates the classifier model for retraining."""
    hparams = self._get_hparams_or_default(hparams)

    module_layer = hub_loader.HubKerasLayerV1V2(
        self.model_spec.uri, trainable=hparams.do_fine_tuning)
    self.model = hub_lib.build_model(module_layer, hparams,
                                     self.model_spec.input_image_shape,
                                     self.num_classes)
    if with_loss_and_metrics:
      # Adds loss and metrics in the keras model.
      self.model.compile(
          loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
          metrics=['accuracy'])
```
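Conceptually, build_model wraps a TF Hub feature-vector layer with a dropout + dense softmax head. A hedged Keras equivalent for efficientnet_lite0; the hub URL, dropout rate, and frozen backbone are assumptions for illustration, since the real values come from model_spec and hub_lib.get_default_hparams():

```python
import tensorflow as tf
import tensorflow_hub as hub

# Assumed TF Hub feature-vector module and hyperparameters.
num_classes = 5
base = hub.KerasLayer(
    'https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
    trainable=False)

model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(224, 224, 3)),
    base,
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(num_classes, activation='softmax'),
])
model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
    metrics=['accuracy'])
```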
The custom classification model base class that is retrained on the data:

```python
class ClassificationModel(custom_model.CustomModel):
  """The abstract base class that represents a Tensorflow classification model."""

  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
                           ExportFormat.SAVED_MODEL, ExportFormat.TFJS)

  def __init__(self, model_spec, index_to_label, shuffle, train_whole_model):
    """Initialize an instance with data, deploy mode and other related parameters.

    Args:
      model_spec: Specification for the model.
      index_to_label: A list that maps from index to label class name.
      shuffle: Whether the data should be shuffled.
      train_whole_model: If true, the Hub module is trained together with the
        classification layer on top. Otherwise, only train the top
        classification layer.
    """
    super(ClassificationModel, self).__init__(model_spec, shuffle)
    self.index_to_label = index_to_label
    self.num_classes = len(index_to_label)
    self.train_whole_model = train_whole_model

  def evaluate(self, data, batch_size=32):
    """Evaluates the model.

    Args:
      data: Data to be evaluated.
      batch_size: Number of samples per evaluation step.

    Returns:
      The loss value and accuracy.
    """
    ds = data.gen_dataset(
        batch_size, is_training=False, preprocess=self.preprocess)
    return self.model.evaluate(ds)

  def predict_top_k(self, data, k=1, batch_size=32):
    """Predicts the top-k predictions."""
```
And the CustomModel base class underneath it:

```python
class CustomModel(abc.ABC):
  """The abstract base class that represents a Tensorflow classification model."""

  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.SAVED_MODEL,
                           ExportFormat.TFJS)

  def __init__(self, model_spec, shuffle):
    """Initialize an instance with data, deploy mode and other related parameters.

    Args:
      model_spec: Specification for the model.
      shuffle: Whether the training data should be shuffled.
    """
    self.model_spec = model_spec
    self.shuffle = shuffle
    self.model = None
    # TODO(yuqili): remove this method once preprocess for image classifier is
    # also moved to DataLoader part.
    self.preprocess = None

  @abc.abstractmethod
  def train(self, train_data, validation_data=None, **kwargs):
    return

  def summary(self):
    self.model.summary()

  @abc.abstractmethod
  def evaluate(self, data, **kwargs):
    return
```
Finally, the helpers that actually convert and save the TFLite model:

```python
def export_tflite(model,
                  tflite_filepath,
                  quantization_config=None,
                  convert_from_saved_model_tf2=False,
                  preprocess=None,
                  supported_ops=(tf.lite.OpsSet.TFLITE_BUILTINS,)):
  """Converts the retrained model to tflite format and saves it.

  Args:
    model: model to be converted to tflite.
    tflite_filepath: File path to save tflite model.
    quantization_config: Configuration for post-training quantization.
    convert_from_saved_model_tf2: Convert to TFLite from saved_model in TF 2.x.
    preprocess: A preprocess function to apply on the dataset.
      # TODO(wangtz): Remove when preprocess is split off from CustomModel.
    supported_ops: A list of supported ops in the converted TFLite file.
  """
  if tflite_filepath is None:
    raise ValueError(
        "TFLite filepath couldn't be None when exporting to tflite.")

  if compat.get_tf_behavior() == 1:
    lite = tf.compat.v1.lite
  else:
    lite = tf.lite

  convert_from_saved_model = (
      compat.get_tf_behavior() == 1 or convert_from_saved_model_tf2)
  with _create_temp_dir(convert_from_saved_model) as temp_dir_name:
    if temp_dir_name:
      save_path = os.path.join(temp_dir_name, 'saved_model')
      model.save(save_path, include_optimizer=False, save_format='tf')
      converter = lite.TFLiteConverter.from_saved_model(save_path)
    else:
      converter = lite.TFLiteConverter.from_keras_model(model)

    if quantization_config:
      converter = quantization_config.get_converter_with_quantization(
          converter, preprocess=preprocess)

    converter.target_spec.supported_ops = supported_ops
    tflite_model = converter.convert()

  with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
    f.write(tflite_model)


def get_lite_runner(tflite_filepath, model_spec=None):
  """Gets `LiteRunner` from file path to TFLite model and `model_spec`."""
  # Gets the functions to handle the input & output indexes if exists.
  reorder_input_details_fn = None
  if hasattr(model_spec, 'reorder_input_details'):
    reorder_input_details_fn = model_spec.reorder_input_details

  reorder_output_details_fn = None
  if hasattr(model_spec, 'reorder_output_details'):
    reorder_output_details_fn = model_spec.reorder_output_details

  lite_runner = LiteRunner(tflite_filepath, reorder_input_details_fn,
                           reorder_output_details_fn)
  return lite_runner
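For reference, the float16 export shown earlier ultimately comes down to standard TFLite post-training quantization flags on the converter. A minimal standalone sketch of that recipe outside of Model Maker (the tiny model here is a stand-in for any trained tf.keras model; this illustrates the standard converter API, not Model Maker's exact internals):

```python
import tensorflow as tf

# Stand-in for any trained tf.keras model.
keras_model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(224, 224, 3)),
    tf.keras.layers.Conv2D(8, 3, activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(5, activation='softmax'),
])

converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]  # store weights as float16
tflite_fp16 = converter.convert()

with tf.io.gfile.GFile('model_fp16.tflite', 'wb') as f:
    f.write(tflite_fp16)
```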