分类|基于VGG16的猫狗识别(二分类)

基于VGG16的猫狗识别(二分类)
Python代码如下:

# Binary cat-vs-dog classifier: a VGG16 convolutional base topped with a
# single sigmoid unit is trained on directory-based image generators, the
# accuracy curves are plotted, and the trained model is saved to disk.
import matplotlib.pyplot as plt
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 1. Data
train_path = r"D:\Desktop\Dog_Cat\train"  # training-set directory
valid_path = r"D:\Desktop\Dog_Cat\valid"  # validation-set directory

# The training generator rescales pixel values and applies random
# rotation / horizontal shift augmentation.
i1 = ImageDataGenerator(rescale=1/255, rotation_range=40, width_shift_range=0.2)
# The validation generator only rescales; validation images must not be
# randomly augmented, otherwise the reported metric is not comparable
# from epoch to epoch.
i2 = ImageDataGenerator(rescale=1/255)

f1 = i1.flow_from_directory(train_path, target_size=(150, 150),
                            batch_size=20, class_mode="binary")
# Fixed: the original built this iterator from `i1` (the augmenting
# generator) and left `i2` unused.
f2 = i2.flow_from_directory(valid_path, target_size=(150, 150),
                            batch_size=20, class_mode="binary")

# 2. Build the model
model = Sequential()
vgg = VGG16(include_top=False, input_shape=(150, 150, 3))
vgg.summary()  # summary before freezing

# Freeze every layer whose index is below 17; layers from index 17 onward
# stay trainable.
# NOTE(review): presumably index 17 is the last conv layer of block 5 --
# confirm against the vgg.summary() output.
for index, layer in enumerate(vgg.layers):
    layer.trainable = index >= 17
vgg.summary()  # summary after freezing, to compare trainable parameter counts

model.add(vgg)
model.add(Flatten())
model.add(Dense(units=1, activation="sigmoid"))
model.compile(optimizer=RMSprop(learning_rate=1E-4),
              loss="binary_crossentropy", metrics=["acc"])
model.summary()

# 3. Train
# Fixed: `validation_data` had been corrupted into a URL in the source.
# `Model.fit` accepts generators directly; `fit_generator` is deprecated.
history = model.fit(f1, epochs=15, validation_data=f2)

# 4. Plot training / validation accuracy
# SimHei is selected so the Chinese legend labels render.
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']
accuracy = history.history['acc']
val_accuracy = history.history['val_acc']
epochs = range(1, len(accuracy) + 1)
plt.plot(epochs, accuracy, label='训练精度', c='r')
plt.plot(epochs, val_accuracy, label='验证精度', c='b')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.title('author:peiyuanman')
plt.legend()
plt.show()

# 5. Save the trained model
model.save('DogCatModelVGG16.h5')

    推荐阅读