基于vgg16的猫狗识别(二分类)
python代码如下:
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import RMSprop
import matplotlib.pyplot as plttrain_path = r"D:\Desktop\Dog_Cat\train" #训练集目录
valid_path = r"D:\Desktop\Dog_Cat\valid" #验证集目录
i1 = ImageDataGenerator(rescale=1/255, rotation_range=40, width_shift_range=0.2)
i2 = ImageDataGenerator(rescale=1/255)
f1 = i1.flow_from_directory(train_path, target_size=(150, 150), batch_size=20, class_mode="binary")
f2 = i1.flow_from_directory(valid_path, target_size=(150, 150), batch_size=20, class_mode="binary")
# 2.构建模型
model = Sequential()
vgg = VGG16(include_top=False, input_shape=(150, 150, 3))
vgg.summary()
for i, j in enumerate(vgg.layers):
if i >= 17:
j.trainable = True
else:
j.trainable = False
vgg.summary()
model.add(vgg)
model.add(Flatten())
model.add(Dense(units=1, activation="sigmoid"))
model.compile(optimizer=RMSprop(learning_rate=1E-4), loss="binary_crossentropy", metrics=["acc"])
model.summary()
history = model.fit_generator(generator=f1, epochs=15, validation_data=https://www.it610.com/article/f2)
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']
accuracy = history.history['acc']
val_accuracy = history.history['val_acc']
epochs = range(1, len(accuracy)+1)
plt.plot(epochs, accuracy, label='训练精度', c = 'r')
plt.plot(epochs, val_accuracy, label='验证精度', c = 'b')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.title('author:peiyuanman')
plt.legend()
plt.show()model.save('DogCatModelVGG16.h5')
推荐阅读
- anaconda|anaconda和ts
- 人工智能|PyTorchの可视化工具
- 深度学习|卷积神经网络 AlexNet详解
- 深度学习|Pytorch总结五之 模型选择、?拟合和过拟合
- #|【Task04】前沿学术数据分析AcademicTrends-论文种类分类
- nlp|NLP之文本分类(一)---文本分类描述
- NLP|Transformer - Attention Is All You Need - 跟李沐学AI
- #|新闻主题分类任务——torchtext 库进行文本分类
- 自然语言处理|自然语言处理(七)(AG_NEWS新闻分类任务(TORCHTEXT))