# MobileViT: A mobile-friendly Transformer-based model for image classification
Reference paper: https://arxiv.org/abs/2110.02178
Reference link: https://keras.io/examples/vision/mobilevit/
A Transformer-based image classifier.
## Imports

```python
import tensorflow as tf
from keras.applications import imagenet_utils
from tensorflow import keras
from tensorflow.keras import layers

import tensorflow_datasets as tfds
import tensorflow_addons as tfa

tfds.disable_progress_bar()
```
## Hyperparameters

```python
# Values are from table 4.
patch_size = 4  # 2x2, for the Transformer blocks.
image_size = 256
expansion_factor = 2  # Expansion factor for the MobileNetV2 blocks.
```
## MobileViT utilities

The MobileViT architecture comprises the following blocks:

1. Strided 3x3 convolutions that process the input image.
2. MobileNetV2-style inverted residual blocks for downsampling the resolution of the intermediate feature maps.
3. MobileViT blocks that combine the benefits of Transformers and convolutions, presented in the figure below (taken from the original paper).
*Figure: the MobileViT block (from the original paper).*
```python
def conv_block(x, filters=16, kernel_size=3, strides=2):
    conv_layer = layers.Conv2D(
        filters, kernel_size, strides=strides, activation=tf.nn.swish, padding="same"
    )
    return conv_layer(x)


# Reference: https://git.io/JKgtC
def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
    m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = layers.BatchNormalization()(m)
    m = tf.nn.swish(m)

    if strides == 2:
        m = layers.ZeroPadding2D(padding=imagenet_utils.correct_pad(m, 3))(m)
    m = layers.DepthwiseConv2D(
        3, strides=strides, padding="same" if strides == 1 else "valid", use_bias=False
    )(m)
    m = layers.BatchNormalization()(m)
    m = tf.nn.swish(m)

    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization()(m)

    if tf.math.equal(x.shape[-1], output_channels) and strides == 1:
        return layers.Add()([m, x])
    return m


# Reference:
# https://keras.io/examples/vision/image_classification_with_vision_transformer/
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.swish)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, x])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=[x.shape[-1] * 2, x.shape[-1]], dropout_rate=0.1)
        # Skip connection 2.
        x = layers.Add()([x3, x2])

    return x


def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    # Local projection with convolutions.
    local_features = conv_block(x, filters=projection_dim, strides=strides)
    local_features = conv_block(
        local_features, filters=projection_dim, kernel_size=1, strides=strides
    )

    # Unfold into patches and then pass through Transformers.
    num_patches = int((local_features.shape[1] * local_features.shape[2]) / patch_size)
    non_overlapping_patches = layers.Reshape((patch_size, num_patches, projection_dim))(
        local_features
    )
    global_features = transformer_block(
        non_overlapping_patches, num_blocks, projection_dim
    )

    # Fold into conv-like feature-maps.
    folded_feature_map = layers.Reshape((*local_features.shape[1:-1], projection_dim))(
        global_features
    )

    # Apply point-wise conv -> concatenate with the input features.
    folded_feature_map = conv_block(
        folded_feature_map, filters=x.shape[-1], kernel_size=1, strides=strides
    )
    local_global_features = layers.Concatenate(axis=-1)([x, folded_feature_map])

    # Fuse the local and global features using a convolution layer.
    local_global_features = conv_block(
        local_global_features, filters=projection_dim, strides=strides
    )

    return local_global_features
```
More on the MobileViT block: local relationships are first captured with a 3x3 convolution followed by a point-wise projection. The resulting feature map is unfolded into non-overlapping patches of shape `(patch_size, num_patches, projection_dim)` and run through a stack of Transformer layers, which model global relationships among the patches. The global features are then folded back into a conv-style feature map, projected back to the input channel count with a point-wise convolution, concatenated with the original input, and finally fused with one more 3x3 convolution. A minimal shape check of the unfold/fold round trip is sketched below.
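The following sketch is not part of the original example; the 32x32x64 feature map is an illustrative assumption. It verifies that the `Reshape`-based unfold and fold used in `mobilevit_block` are a lossless round trip:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical feature map: batch of 1, 32x32 spatial grid, 64 channels.
feature_map = tf.random.normal((1, 32, 32, 64))
num_patches = (32 * 32) // 4  # 256, matching the computation in mobilevit_block.

# Unfold into (patch_size, num_patches, channels), as in mobilevit_block.
unfolded = layers.Reshape((4, num_patches, 64))(feature_map)
print(unfolded.shape)  # (1, 4, 256, 64)

# Fold back into a conv-like feature map and confirm nothing was lost.
refolded = layers.Reshape((32, 32, 64))(unfolded)
print(bool(tf.reduce_all(tf.equal(feature_map, refolded))))  # True
```

With these utilities in place, the full MobileViT (XXS) variant can be assembled: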
```python
def create_mobilevit(num_classes=5):
    inputs = keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    # Initial conv-stem -> MV2 block.
    x = conv_block(x, filters=16)
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=16
    )

    # Downsampling with MV2 block.
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=24, strides=2
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )

    # First MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=48, strides=2
    )
    x = mobilevit_block(x, num_blocks=2, projection_dim=64)

    # Second MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=64 * expansion_factor, output_channels=64, strides=2
    )
    x = mobilevit_block(x, num_blocks=4, projection_dim=80)

    # Third MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=80 * expansion_factor, output_channels=80, strides=2
    )
    x = mobilevit_block(x, num_blocks=3, projection_dim=96)
    x = conv_block(x, filters=320, kernel_size=1, strides=1)

    # Classification head.
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)


mobilevit_xxs = create_mobilevit()
mobilevit_xxs.summary()
```
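As a quick sanity check, a hypothetical snippet (not in the original example, run in the same session as the code above) can push a random batch through the untrained model and confirm the output shape:

```python
# Random batch of two images at the expected input resolution.
dummy_batch = tf.random.normal((2, image_size, image_size, 3))
predictions = mobilevit_xxs(dummy_batch, training=False)
print(predictions.shape)  # (2, 5): softmax scores over the 5 default classes.
```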
## Dataset preparation

We use the tf_flowers dataset to validate the model.
```python
batch_size = 64
auto = tf.data.AUTOTUNE
resize_bigger = 280
num_classes = 5


def preprocess_dataset(is_training=True):
    def _pp(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take random crops.
            image = tf.image.resize(image, (resize_bigger, resize_bigger))
            image = tf.image.random_crop(image, (image_size, image_size, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (image_size, image_size))
        label = tf.one_hot(label, depth=num_classes)
        return image, label

    return _pp


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)
    return dataset.batch(batch_size).prefetch(auto)
```
## Load and prepare the dataset
```python
train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)

num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")

train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
```
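Before training, it can be worth peeking at one batch to confirm the pipeline's output shapes. This optional check is an addition, not part of the original example:

```python
# Fetch a single batch from the training pipeline.
images, labels = next(iter(train_dataset))
print(images.shape, labels.shape)  # (64, 256, 256, 3) (64, 5)
```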
## Train a MobileViT (XXS) model
```python
learning_rate = 0.002
label_smoothing_factor = 0.1
epochs = 30

optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing_factor)


def run_experiment(epochs=epochs):
    mobilevit_xxs = create_mobilevit(num_classes=num_classes)
    mobilevit_xxs.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    mobilevit_xxs.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint_callback],
    )
    mobilevit_xxs.load_weights(checkpoint_filepath)
    _, accuracy = mobilevit_xxs.evaluate(val_dataset)
    print(f"Validation accuracy: {round(accuracy * 100, 2)}%")
    return mobilevit_xxs


mobilevit_xxs = run_experiment()
```
Experimental setup: TensorFlow 2.6.0 + CUDA 11.2 + cuDNN 8.1 + Python 3.9.
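Since the point of MobileViT is mobile deployment, a natural next step is exporting the trained model to TensorFlow Lite. The snippet below is a sketch using the standard `tf.lite.TFLiteConverter` API rather than a tested recipe: the `SELECT_TF_OPS` fallback is enabled in case some ops in the model lack TFLite builtins, and the output filename is arbitrary.

```python
# Convert the trained Keras model to a TFLite flatbuffer.
converter = tf.lite.TFLiteConverter.from_keras_model(mobilevit_xxs)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Prefer TFLite builtin ops.
    tf.lite.OpsSet.SELECT_TF_OPS,    # Fall back to TF ops where needed.
]
tflite_model = converter.convert()

with open("mobilevit_xxs.tflite", "wb") as f:
    f.write(tflite_model)
```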