TensorFlow | TensorFlow Lite Model Maker --- Image Classification + Source Code

TFLite_tutorials The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow neural-network model to particular input data when deploying this model for on-device ML applications.
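Commentary: what we want here is a model in .tflite format, which can be deployed on mobile or embedded devices.
The table below lists the task types that TFLite Model Maker currently supports.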

Supported Tasks | Task Utility
Image Classification: tutorial, api | Classify images into predefined categories.
Object Detection: tutorial, api | Detect objects in real time.
Text Classification: tutorial, api | Classify text into predefined categories.
BERT Question Answer: tutorial, api | Find the answer in a certain context for a given question with BERT.
Audio Classification: tutorial, api | Classify audio into predefined categories.
Recommendation: demo, api | Recommend items based on the context information for on-device scenario.

If your tasks are not supported, please first use TensorFlow to retrain a TensorFlow model with transfer learning (following guides like images, text, audio) or train it from scratch, and then convert it to TensorFlow Lite model.
Commentary: if the model you want to train does not fit any of the task types above, train a TensorFlow model first and then convert it to TFLite.
To use TensorFlow Lite Model Maker, we first need to install it:
pip install tflite-model-maker

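Swap in different models and compare the final accuracy, along with the size, inference speed, memory usage, and CPU usage of the resulting TFLite model.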
image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

flower_photos
|__ daisy
    |______ 100080576_f52e8ee070_n.jpg
    |______ 14167534527_781ceb1b7a_n.jpg
    |______ ...
|__ dandelion
    |______ 10043234166_e6dd915111_n.jpg
    |______ 1426682852_e62169221f_m.jpg
    |______ ...
|__ roses
    |______ 102501987_3cdb8e5394_n.jpg
    |______ 14982802401_a3dfb22afb.jpg
    |______ ...
|__ sunflowers
    |______ 12471791574_bb1be83df4.jpg
    |______ 15122112402_cafa41934f.jpg
    |______ ...
|__ tulips
    |______ 13976522214_ccec508fe7.jpg
    |______ 14487943607_651e8062a1_m.jpg
    |______ ...

data = DataLoader.from_folder(image_path)
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
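To see what the two `split` calls produce, the sketch below mirrors the arithmetic in plain Python. It assumes that `split(fraction)` puts `int(size * fraction)` items in the first part and the remainder in the second, and that the flower dataset holds 3,670 images. Both are assumptions rather than statements from this post, although they are consistent with the 367 test images reported later. `split_sizes` is a hypothetical helper, not a library function.

```python
def split_sizes(size, fraction):
    """Return (first, second) sizes for a split at the given fraction.

    Assumption: the first part gets int(size * fraction) items and the
    second part gets whatever is left.
    """
    first = int(size * fraction)
    return first, size - first


total = 3670  # assumed size of the flower_photos dataset
train_size, rest_size = split_sizes(total, 0.8)
validation_size, test_size = split_sizes(rest_size, 0.5)

print(train_size, validation_size, test_size)
```

Under these assumptions the data ends up split roughly 80/10/10 between training, validation, and test.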

assert tf.__version__.startswith('2')

This checks that the TensorFlow version string starts with '2'.
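Training results: train_acc = 0.9698, val_acc = 0.9375, test_acc = 0.9210. Overall this matches the usual generalization pattern, with accuracy highest on the training set and somewhat lower on unseen data.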
import os
import time

import numpy as np
import tensorflow as tf

from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader

import matplotlib.pyplot as plt

assert tf.__version__.startswith('2')

# Download and extract the flower dataset.
image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

# Load the images and split them into training, validation, and test sets.
data = DataLoader.from_folder(image_path)
# data = data.gen_dataset(batch_size=1)
train_data, rest_data = data.split(0.8)
# for batch in data.take(1):
#   print(batch)
#   break
validation_data, test_data = rest_data.split(0.5)

# Retrain the classifier on top of efficientnet_lite0.
model = image_classifier.create(train_data,
                                validation_data=validation_data,
                                model_spec=model_spec.get('efficientnet_lite0'),
                                epochs=20)

loss, accuracy = model.evaluate(test_data)

# Export the TFLite model together with its label file.
model.export(export_dir='./testTFlite',
             export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))

# Time the evaluation of the exported TFLite model on the test set.
start = time.time()
print(model.evaluate_tflite('./testTFlite/model.tflite', test_data))
end = time.time()
print('elapsed time: ', end - start)

Judging from the output log above, the model loses very little accuracy after quantization; the quantized model is 4.0 MB (efficientnet_lite0).
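Model sizes like the one quoted above can be read directly from the exported files. The following is a minimal sketch using only the standard library; `file_size_mb` and `report_sizes` are hypothetical helpers, and the choice of 1 MB = 1024 * 1024 bytes is an assumption, since the post does not say which unit its figures use.

```python
import os


def file_size_mb(path):
    """Return the size of the file at `path` in MB (1 MB = 1024 * 1024 bytes)."""
    return os.path.getsize(path) / (1024 * 1024)


def report_sizes(paths):
    """Map each existing path to its size in MB, skipping missing files."""
    return {p: round(file_size_mb(p), 1) for p in paths if os.path.exists(p)}


# Paths follow the export calls in this post; files that do not exist are skipped.
print(report_sizes(['./testTFlite/model.tflite',
                    './testTFlite/model_fp16.tflite']))
```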
As the figure below shows, inference runs on a single CPU; test_data contains 367 images and the total elapsed time is 273.43 s.
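The elapsed time above comes from wrapping the evaluation call in a pair of clock reads. A small reusable version of that pattern might look like the sketch below. `timed` is a hypothetical helper, and it uses `time.perf_counter()` instead of the `time.time()` calls in the script because it is intended for measuring intervals; the workload here is only a stand-in.

```python
import time


def timed(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


# Stand-in workload; in this post the call would be something like
# model.evaluate_tflite('./testTFlite/model.tflite', test_data).
result, elapsed = timed(sum, range(1000))
print('result:', result, 'elapsed time:', elapsed)
```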
config = QuantizationConfig.for_float16()
model.export(export_dir='./testTFlite',
             tflite_filename='model_fp16.tflite',
             quantization_config=config,
             export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))

If the model is exported as fp16, its size is 6.8 MB (efficientnet_lite0) and inference takes 5.54 s, which is much faster.
model = image_classifier.create(train_data,
                                validation_data=validation_data,
                                model_spec=model_spec.get('mobilenet_v2'),
                                epochs=20)

After switching the model to mobilenet_v2, the exported fp16 model is 4.6 MB and inference takes 4.36 s.
inception_v3_spec = image_classifier.ModelSpec(
    uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]
model = image_classifier.create(train_data,
                                validation_data=validation_data,
                                model_spec=inception_v3_spec,
                                epochs=20)

After switching the model to inception_v3, the exported fp16 model is 43.8 MB (inception_v3) and inference takes 25.31 s.
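The figures reported in this post are totals over the whole test set, so it can help to convert them to an average per-image latency. The sketch below does that arithmetic; it assumes every run evaluated the same 367 test images, which the post states only for the first run. `per_image_ms` is a hypothetical helper.

```python
NUM_TEST_IMAGES = 367  # size of test_data reported in this post

# (label, exported size in MB, total evaluation time in seconds), as reported above.
RUNS = [
    ('efficientnet_lite0 (default export)', 4.0, 273.43),
    ('efficientnet_lite0 (fp16)', 6.8, 5.54),
    ('mobilenet_v2 (fp16)', 4.6, 4.36),
    ('inception_v3 (fp16)', 43.8, 25.31),
]


def per_image_ms(total_seconds, num_images=NUM_TEST_IMAGES):
    """Average latency per image in milliseconds."""
    return total_seconds / num_images * 1000.0


for label, size_mb, seconds in RUNS:
    print(f'{label:38s} {size_mb:5.1f} MB  {per_image_ms(seconds):7.1f} ms/image')
```

Under that assumption the default export averages roughly 745 ms per image, while the fp16 exports range from about 12 ms (mobilenet_v2) to about 69 ms (inception_v3).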
Common Dataset used for tasks.

class DataLoader(object):
  """This class provides generic utilities for loading customized domain data that will be used later in model retraining.

  For different ML problems or tasks, such as image classification, text
  classification etc., a subclass is provided to handle task-specific data
  loading requirements.
  """

  def __init__(self, dataset, size):
    """Init function for class `DataLoader`.

    In most cases, one should use helper functions like `from_folder` to create
    an instance of this class.

    Args:
      dataset: A tf.data.Dataset object that contains a potentially large set of
        elements, where each element is a pair of (input_data, target). The
        `input_data` means the raw input data, like an image, a text etc., while
        the `target` means some ground truth of the raw input data, such as the
        classification label of the image etc.
      size: The size of the dataset. tf.data.Dataset doesn't support a function
        to get the length directly since it's lazy-loaded and may be infinite.
    """
    self._dataset = dataset
    self._size = size

  def gen_dataset(self,
                  batch_size=1,
                  is_training=False,
                  shuffle=False,
                  input_pipeline_context=None,
                  preprocess=None,
                  drop_remainder=False):
    """Generate a shared and batched tf.data.Dataset for training/evaluation.

Image dataloader

class ImageClassifierDataLoader(dataloader.ClassificationDataLoader):
  """DataLoader for image classifier."""

  @classmethod
  def from_folder(cls, filename, shuffle=True):
    """Image analysis for image classification load images with labels.

    Assume the image data of the same label are in the same subdirectory.

    Args:
      filename: Name of the file.
      shuffle: boolean, if shuffle, random shuffle data.

    Returns:
      ImageDataset containing images and labels and other related info.
    """

  @classmethod
  def from_tfds(cls, name):
    """Loads data from tensorflow_datasets."""
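The `from_folder` docstring assumes one subdirectory per label. The sketch below illustrates that convention with the standard library only. It is not the library's implementation: `scan_image_folder` is a hypothetical helper, and both the sorted label order and the accepted file extensions are assumptions made for the example.

```python
import os


def scan_image_folder(root):
    """Return (index_to_label, samples) for a folder-per-label image layout.

    index_to_label is the sorted list of subdirectory names; samples is a
    list of (file_path, label_index) pairs.
    """
    index_to_label = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name)))
    samples = []
    for index, label in enumerate(index_to_label):
        label_dir = os.path.join(root, label)
        for filename in sorted(os.listdir(label_dir)):
            if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
                samples.append((os.path.join(label_dir, filename), index))
    return index_to_label, samples


# Only runs if a local flower_photos folder is present.
if os.path.isdir('flower_photos'):
    labels, samples = scan_image_folder('flower_photos')
    print(labels, len(samples))
```

The list of label names plays the same role as the `index_to_label` argument that the classifier classes below receive.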

ImageNet preprocessing

class Preprocessor(object):
  """Preprocessing for image classification."""

  def __init__(self,
               input_shape,
               num_classes,
               mean_rgb,
               stddev_rgb,
               use_augmentation=False):
    self.input_shape = input_shape
    self.num_classes = num_classes
    self.mean_rgb = mean_rgb
    self.stddev_rgb = stddev_rgb
    self.use_augmentation = use_augmentation

  def __call__(self, image, label, is_training=True):
    if self.use_augmentation:
      return self._preprocess_with_augmentation(image, label, is_training)
    return self._preprocess_without_augmentation(image, label)

  def _preprocess_with_augmentation(self, image, label, is_training):
    """Image preprocessing method with data augmentation."""
    image_size = self.input_shape[0]
    if is_training:
      image = preprocess_for_train(image, image_size)
    else:
      image = preprocess_for_eval(image, image_size)

    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)

    label = tf.one_hot(label, depth=self.num_classes)
    return image, label

  # TODO(yuqili): Changes to preprocess to support batch input.
  def _preprocess_without_augmentation(self, image, label):
    """Image preprocessing method without data augmentation."""
    image = tf.cast(image, tf.float32)

    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)

    image = tf.compat.v1.image.resize(image, self.input_shape)
    label = tf.one_hot(label, depth=self.num_classes)
    return image, label
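Both preprocessing paths above end with the same two steps: per-channel normalization, `(image - mean_rgb) / stddev_rgb`, and one-hot encoding of the label. The sketch below reproduces those two steps for a single pixel in plain Python. The mean and standard deviation values are illustrative assumptions (the real ones come from the model spec), and `normalize_pixel` and `one_hot` are hypothetical helpers.

```python
def normalize_pixel(pixel, mean_rgb, stddev_rgb):
    """Apply (value - mean) / stddev to each of the R, G, B channels."""
    return [(v - m) / s for v, m, s in zip(pixel, mean_rgb, stddev_rgb)]


def one_hot(label, num_classes):
    """Return a list of length num_classes with 1.0 at index `label`."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


# Illustrative statistics only; the real values come from the model spec.
mean_rgb = [127.0, 127.0, 127.0]
stddev_rgb = [128.0, 128.0, 128.0]

print(normalize_pixel([255.0, 127.0, 0.0], mean_rgb, stddev_rgb))
print(one_hot(2, num_classes=5))
```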

class ImageClassifier(classification_model.ClassificationModel):
  """ImageClassifier class for inference and exporting to tflite."""

  def __init__(self,
               model_spec,
               index_to_label,
               shuffle=True,
               hparams=hub_lib.get_default_hparams(),
               use_augmentation=False,
               representative_data=None):
    """Init function for ImageClassifier class.

    Args:
      model_spec: Specification for the model.
      index_to_label: A list that map from index to label class name.
      shuffle: Whether the data should be shuffled.
      hparams: A namedtuple of hyperparameters. This function expects
        .dropout_rate: The fraction of the input units to drop, used in dropout
          layer.
        .do_fine_tuning: If true, the Hub module is trained together with the
          classification layer on top.
      use_augmentation: Use data augmentation for preprocessing.
      representative_data: Representative dataset for full integer
        quantization. Used when converting the keras model to the TFLite model
        with full integer quantization.
    """
    super(ImageClassifier, self).__init__(model_spec, index_to_label, shuffle,
                                          hparams.do_fine_tuning)
    num_classes = len(index_to_label)
    self._hparams = hparams
    self.preprocess = image_preprocessing.Preprocessor(
        self.model_spec.input_image_shape,
        num_classes,
        self.model_spec.mean_rgb,
        self.model_spec.stddev_rgb,
        use_augmentation=use_augmentation)
    self.history = None  # Training history that returns from `keras_model.fit`.
    self.representative_data = representative_data

  def _get_tflite_input_tensors(self, input_tensors):
    """Gets the input tensors for the TFLite model."""
    return input_tensors

  def create_model(self, hparams=None, with_loss_and_metrics=False):
    """Creates the classifier model for retraining."""
    hparams = self._get_hparams_or_default(hparams)

    module_layer = hub_loader.HubKerasLayerV1V2(
        self.model_spec.uri, trainable=hparams.do_fine_tuning)
    self.model = hub_lib.build_model(module_layer, hparams,
                                     self.model_spec.input_image_shape,
                                     self.num_classes)
    if with_loss_and_metrics:
      # Adds loss and metrics in the keras model.
      self.model.compile(
          loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
          metrics=['accuracy'])
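The model is compiled with `CategoricalCrossentropy(label_smoothing=0.1)`. As a rough illustration of what label smoothing does to a one-hot target, the sketch below blends the target with a uniform distribution, `y * (1 - s) + s / num_classes`, and then computes the cross-entropy by hand. The predicted probabilities are hypothetical, and `smooth_labels` and `cross_entropy` are illustrative helpers rather than Keras functions.

```python
import math


def smooth_labels(one_hot, label_smoothing):
    """Blend a one-hot target with the uniform distribution over its classes."""
    num_classes = len(one_hot)
    return [y * (1.0 - label_smoothing) + label_smoothing / num_classes
            for y in one_hot]


def cross_entropy(target, probs):
    """Categorical cross-entropy between a target distribution and predictions."""
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0.0)


target = smooth_labels([0.0, 0.0, 1.0, 0.0, 0.0], label_smoothing=0.1)
probs = [0.05, 0.05, 0.80, 0.05, 0.05]  # hypothetical softmax output

print([round(t, 2) for t in target])
print(round(cross_entropy(target, probs), 4))
```

With five classes and a smoothing factor of 0.1, the true class keeps a weight of 0.92 and each other class receives 0.02, so the loss no longer rewards pushing the predicted probability all the way to 1.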

Custom classification model that is already retained by data

class ClassificationModel(custom_model.CustomModel):
  """The abstract base class that represents a Tensorflow classification model."""

  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
                           ExportFormat.SAVED_MODEL, ExportFormat.TFJS)

  def __init__(self, model_spec, index_to_label, shuffle, train_whole_model):
    """Initialize an instance with data, deploy mode and other related parameters.

    Args:
      model_spec: Specification for the model.
      index_to_label: A list that map from index to label class name.
      shuffle: Whether the data should be shuffled.
      train_whole_model: If true, the Hub module is trained together with the
        classification layer on top. Otherwise, only train the top
        classification layer.
    """
    super(ClassificationModel, self).__init__(model_spec, shuffle)
    self.index_to_label = index_to_label
    self.num_classes = len(index_to_label)
    self.train_whole_model = train_whole_model

  def evaluate(self, data, batch_size=32):
    """Evaluates the model.

    Args:
      data: Data to be evaluated.
      batch_size: Number of samples per evaluation step.

    Returns:
      The loss value and accuracy.
    """
    ds = data.gen_dataset(
        batch_size, is_training=False, preprocess=self.preprocess)
    return self.model.evaluate(ds)

  def predict_top_k(self, data, k=1, batch_size=32):
    """Predicts the top-k predictions.
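`evaluate` batches the data with a default `batch_size` of 32, so the number of evaluation steps follows from the dataset size. The sketch below works that out for the 367 test images reported in this post. `num_batches` is a hypothetical helper, and it assumes the last partial batch is kept unless `drop_remainder` is set, matching the `drop_remainder=False` default in `gen_dataset`.

```python
import math


def num_batches(num_examples, batch_size, drop_remainder=False):
    """Number of batches produced when batching num_examples items."""
    if drop_remainder:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)


# 367 is the test-set size reported in this post; 32 is evaluate()'s default.
print(num_batches(367, 32))
print(num_batches(367, 32, drop_remainder=True))
```

Here the final batch holds only 15 images (367 - 11 * 32), which is why keeping or dropping the remainder changes the step count from 12 to 11.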

class CustomModel(abc.ABC):
  """The abstract base class that represents a Tensorflow classification model."""

  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.SAVED_MODEL,
                           ExportFormat.TFJS)

  def __init__(self, model_spec, shuffle):
    """Initialize an instance with data, deploy mode and other related parameters.

    Args:
      model_spec: Specification for the model.
      shuffle: Whether the training data should be shuffled.
    """
    self.model_spec = model_spec
    self.shuffle = shuffle
    self.model = None
    # TODO(yuqili): remove this method once preprocess for image classifier is
    # also moved to DataLoader part.
    self.preprocess = None

  @abc.abstractmethod
  def train(self, train_data, validation_data=None, **kwargs):
    return

  def summary(self):
    self.model.summary()

  @abc.abstractmethod
  def evaluate(self, data, **kwargs):
    return
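`CustomModel` relies on `abc.ABC` and `@abc.abstractmethod` to force subclasses to provide `train` and `evaluate`. The toy example below shows that mechanism in isolation. `BaseModel` and `MeanModel` are hypothetical classes invented for this sketch and have nothing to do with the library beyond sharing the two method names.

```python
import abc


class BaseModel(abc.ABC):
    """Hypothetical stand-in for an abstract model base class."""

    @abc.abstractmethod
    def train(self, train_data, validation_data=None, **kwargs):
        """Subclasses must implement training."""

    @abc.abstractmethod
    def evaluate(self, data, **kwargs):
        """Subclasses must implement evaluation."""


class MeanModel(BaseModel):
    """Toy subclass: 'training' just stores the mean of the data."""

    def train(self, train_data, validation_data=None, **kwargs):
        self.mean = sum(train_data) / len(train_data)
        return self.mean

    def evaluate(self, data, **kwargs):
        # Mean absolute error against the stored mean.
        return sum(abs(x - self.mean) for x in data) / len(data)


toy_model = MeanModel()
toy_model.train([1.0, 2.0, 3.0])
print(toy_model.evaluate([2.0, 4.0]))

try:
    BaseModel()  # abstract methods are not implemented
except TypeError as error:
    print('cannot instantiate:', error)
```

Instantiating the base class directly raises `TypeError`, which is what keeps callers from using a model type that has no training or evaluation logic.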

def export_tflite(model,
                  tflite_filepath,
                  quantization_config=None,
                  convert_from_saved_model_tf2=False,
                  preprocess=None,
                  supported_ops=(tf.lite.OpsSet.TFLITE_BUILTINS,)):
  """Converts the retrained model to tflite format and saves it.

  Args:
    model: model to be converted to tflite.
    tflite_filepath: File path to save tflite model.
    quantization_config: Configuration for post-training quantization.
    convert_from_saved_model_tf2: Convert to TFLite from saved_model in TF 2.x.
    preprocess: A preprocess function to apply on the dataset.
        # TODO(wangtz): Remove when preprocess is split off from CustomModel.
    supported_ops: A list of supported ops in the converted TFLite file.
  """
  if tflite_filepath is None:
    raise ValueError(
        "TFLite filepath couldn't be None when exporting to tflite.")

  if compat.get_tf_behavior() == 1:
    lite = tf.compat.v1.lite
  else:
    lite = tf.lite

  convert_from_saved_model = (
      compat.get_tf_behavior() == 1 or convert_from_saved_model_tf2)
  with _create_temp_dir(convert_from_saved_model) as temp_dir_name:
    if temp_dir_name:
      save_path = os.path.join(temp_dir_name, 'saved_model')
      model.save(save_path, include_optimizer=False, save_format='tf')
      converter = lite.TFLiteConverter.from_saved_model(save_path)
    else:
      converter = lite.TFLiteConverter.from_keras_model(model)

    if quantization_config:
      converter = quantization_config.get_converter_with_quantization(
          converter, preprocess=preprocess)

    converter.target_spec.supported_ops = supported_ops
    tflite_model = converter.convert()

  with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
    f.write(tflite_model)


def get_lite_runner(tflite_filepath, model_spec=None):
  """Gets `LiteRunner` from file path to TFLite model and `model_spec`."""
  # Gets the functions to handle the input & output indexes if exists.
  reorder_input_details_fn = None
  if hasattr(model_spec, 'reorder_input_details'):
    reorder_input_details_fn = model_spec.reorder_input_details

  reorder_output_details_fn = None
  if hasattr(model_spec, 'reorder_output_details'):
    reorder_output_details_fn = model_spec.reorder_output_details

  lite_runner = LiteRunner(tflite_filepath, reorder_input_details_fn,
                           reorder_output_details_fn)
  return lite_runner
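`get_lite_runner` looks up optional hooks on `model_spec` with `hasattr` and falls back to `None` when they are missing, which also covers the `model_spec=None` default. The sketch below shows the same optional-hook pattern using `getattr` with a default. `PlainSpec`, `ReorderingSpec`, and both helper functions are hypothetical and exist only for this illustration.

```python
class PlainSpec:
    """Hypothetical model spec without reorder hooks."""


class ReorderingSpec:
    """Hypothetical model spec that reverses the input order."""

    def reorder_input_details(self, input_details):
        return list(reversed(input_details))


def resolve_input_hook(model_spec):
    """Return the spec's reorder_input_details method, or None if absent."""
    return getattr(model_spec, 'reorder_input_details', None)


def apply_input_hook(model_spec, input_details):
    """Apply the optional hook; fall back to the original order."""
    hook = resolve_input_hook(model_spec)
    return hook(input_details) if hook is not None else input_details


print(apply_input_hook(PlainSpec(), ['ids', 'mask']))
print(apply_input_hook(ReorderingSpec(), ['ids', 'mask']))
print(apply_input_hook(None, ['ids', 'mask']))
```

Only the spec that defines the hook changes the order; a spec without it, or no spec at all, leaves the inputs untouched.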
