Neural|图像分类 - cifar100 实验研究
为了解决 cifar100 val_acc 过低的问题,本质上是过拟合问题,所以特地去 papers with code 网站上看了下 cifar100 benchmark 目前第一名做到了多少,如下图所示,val_cc = 0.96,有点东西哈,所以目前要做的是研究 SAM (Sharpness-Aware Minimization),主要用于提升模型的泛化性。
文章图片
我这里先把拿到的代码跑了下,不过数据集是 cifar10,val_acc = 0.97,我觉得还是很稳的,目前正在跑 cifar100,不过代码是 Pytorch 版本的,后续需要迁移到 Tensorflow 上来。cifar10 训练截图如下所示。代码地址: https://github.com/davda54/sam
文章图片
文章图片
更新: 跑完 cifar100 了,但是 val_acc 和想象中的有差别吧,总的来说是比之前的 0.8 有提升了,目前是 val_acc = 0.83,训练截图如下所示
文章图片
文章图片
特别说明: 模型训练 log 里的 Name,比如 DenseNet121_RandomFlip_…/validation,实际上用的网络是 DenseNet121,数据增强用的是 RandAugmentation,可以忽略 RandomFlip,因为上一个 flower_photos 实验研究遗留下来的原因
数据集: visual_domain_decathlon/cifar100
Config description: Data based on “CIFAR-100”, with images resized isotropically to have a shorter size of 72 pixels
train: 40000 张图片
test: 10000 张图片
validation: 10000 张图片
类别数为 100
训练的时候采用 180 x 180 x 3
其中 NASNetMobile 特殊一些,需要 resize 成 224 x 224 x 3
第一阶段,我们利用在 ImageNet 上做过预训练的模型来做 feature extraction,意思就是要 freeze 预训练模型的卷积部分,然后只训练新添加的 top-classifier,训练结果如下图所示
文章图片
此处我们可以看到,val_acc 最高的是 ResNet50,值为 0.7421,其实最高的是 ResNet101,但是考虑到计算量,我们取 ResNet50。不过这里比较神奇的是 ResNet50 的 val_acc 竟然是最高的,猜测是数据集的分辨率大小问题,毕竟我们此次的任务,原始图像分辨率只有 72 x 72 x 3。
我们粘贴一下第一阶段的代码
rand_aug = iaa.RandAugment(n=3, m=7)def augment(images):
# Input to `augment()` is a TensorFlow tensor which
# is not supported by `imgaug`. This is why we first
# convert it to its `numpy` variant.
images = tf.cast(images, tf.uint8)
return rand_aug(images=images.numpy())AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.shuffle(buffer_size=len(train_ds)).cache().batch(batch_size).map(
lambda x, y: (tf.py_function(augment, [x], [tf.float32])[0], y), num_parallel_calls=AUTOTUNE).prefetch(
buffer_size=AUTOTUNE)
val_ds = val_ds.cache().batch(batch_size).prefetch(buffer_size=AUTOTUNE)preprocess_input = tf.keras.applications.resnet.preprocess_input
base_model = tf.keras.applications.ResNet101(input_shape=img_size,
include_top=False,
weights='imagenet')
这里,我没有粘贴全部的代码,如果需要查看源码,请到这里: https://github.com/MaoXianXin/Tensorflow_tutorial
文章图片
如上图所示,我们需要 checkout 对应的分支。
基于此,我们对 ResNet50 和 InceptionResNetV2 分别做了 fine-tune,结果如下所示
文章图片
此处未对第一阶段的所有模型做 fine-tune,从上图可以发现,还是 ResNet50 的 val_acc 略高,不过到这里为止,我们在 visual_domain_decathlon/cifar100 上的 val_acc 还是低了些,只有 0.8041,需要做改进。
preprocess_input = tf.keras.applications.resnet50.preprocess_input
base_model = tf.keras.applications.ResNet50(input_shape=img_size,
include_top=False,
weights='imagenet')
base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))# Fine-tune from this layer onwards
fine_tune_at = 120# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
layer.trainable = Falseinputs = tf.keras.Input(shape=img_size)
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(num_classes)(x)
model = tf.keras.Model(inputs, outputs)
model.load_weights('./save_model/my_model_1')
print(model.summary())
【Neural|图像分类 - cifar100 实验研究】最后上一下 fine-tune 阶段的代码,这里需要注意的是,不同模型,网络层数不一样,所以 fine_tune_at 这个参数我们需要看情况而定,还有就是加载模型的地址不要搞错。
推荐阅读
- Java|Java OpenCV图像处理之SIFT角点检测详解
- jQuery插件
- 1.2序列通用操作
- ImageLoaders 加载图像
- 前沿论文|论文精读(Neural Architecture Search without Training)
- JAVA图像处理系列(四)——噪声
- 茶叶分类(五)(茶叶分为六大类,做茶的人只分两类)
- 使用交叉点观察器延迟加载图像以提高性能
- 神经网络Neural|神经网络Neural Networks
- Figure|Figure 图像