tensorflow2.0 | TensorFlow from Beginner to Mastery (8): Saving and Restoring Models, the TensorFlow Playground, and TensorBoard

In this lesson we introduce three ways to save a model, plus two very useful tools: the TensorFlow Playground and TensorBoard. We only touch on them briefly here and will dig deeper later.
No update yesterday, I was honestly tired. Next up: convolutional neural networks. Let's go!
1. Saving only the weights and biases
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers

# Step 1: load the training and test sets
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Step 2: build the model
def create_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

model = create_model()

# Step 3: compile the model (choose the optimizer, loss function, etc.)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Step 4: train the model for one epoch
model.fit(x=x_train, y=y_train, epochs=1)

# Step 5: evaluate the model
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))

1875/1875 [==============================] - 9s 5ms/step - loss: 0.2195 - accuracy: 0.9350
313/313 [==============================] - 1s 2ms/step - loss: 0.1053 - accuracy: 0.9678
train model, accuracy:96.78%

# Step 6: save the model's weights and biases
model.save_weights('./min.h5')

# Step 7: delete the model
del model

# Step 8: rebuild the model
model = create_model()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Step 9: restore the weights
model.load_weights('./min.h5')

# Step 10: evaluate the restored model
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

313/313 [==============================] - 1s 2ms/step - loss: 0.1053 - accuracy: 0.9678
Restored model, accuracy:96.78%
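Note that load_weights only works if the new model has exactly the same architecture, which is why create_model() is reused above. In TensorFlow 2.x, save_weights also infers the file format from the file name: an .h5 suffix writes HDF5, any other path uses the native TensorFlow checkpoint format. A minimal sketch of the same round trip with the checkpoint format (the './ckpt/min' prefix is just an illustrative path):

# A minimal sketch: saving/restoring weights in the TensorFlow checkpoint
# format instead of HDF5 ('./ckpt/min' is an illustrative path).
model.save_weights('./ckpt/min')      # writes min.index / min.data-* files

restored = create_model()             # architecture must match exactly
restored.compile(optimizer='adam',
                 loss='sparse_categorical_crossentropy',
                 metrics=['accuracy'])
restored.load_weights('./ckpt/min')
loss, acc = restored.evaluate(x_test, y_test)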

2. Saving the entire model
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers

# Step 1: load the training and test sets
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Step 2: build the model
def create_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

model = create_model()

# Step 3: compile the model (choose the optimizer, loss function, etc.)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Step 4: train the model for one epoch
model.fit(x=x_train, y=y_train, epochs=1)

# Step 5: evaluate the model
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))

# Step 6: save the whole model (architecture, weights, optimizer state)
model.save('my_model.h5')  # creates an HDF5 file 'my_model.h5'

# Step 7: delete the model
del model  # deletes the existing model

# Step 8: restore the model
# returns a compiled model identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')

# Step 9: evaluate the restored model
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

1875/1875 [==============================] - 9s 5ms/step - loss: 0.2190 - accuracy: 0.9355
313/313 [==============================] - 1s 2ms/step - loss: 0.1026 - accuracy: 0.9679
train model, accuracy:96.79%
313/313 [==============================] - 1s 2ms/step - loss: 0.1026 - accuracy: 0.9679
Restored model, accuracy:96.79%
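model.save also accepts a directory path; in TensorFlow 2.x a path without an .h5 suffix is written in the SavedModel format instead of HDF5, and load_model restores it the same way. A minimal sketch (the 'saved_model/my_model' directory name is only an illustration):

# A minimal sketch: the same save/restore round trip with the SavedModel format.
# 'saved_model/my_model' is an illustrative directory name.
model.save('saved_model/my_model')      # no .h5 suffix -> SavedModel directory
restored_model = tf.keras.models.load_model('saved_model/my_model')
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))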

3. Resuming training from checkpoints with callbacks
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers

# Step 1: load the training and test sets
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Step 2: build the model
def create_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

model = create_model()

# Step 3: compile the model (choose the optimizer, loss function, etc.)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# ------------------------------ callbacks ------------------------------
logdir = './logs'
checkpoint_path = './checkpoint/min.{epoch:02d}-{val_loss:.2f}.ckpt'

def scheduler(epoch, lr):
    if epoch < 10:
        return lr
    else:
        return lr * tf.math.exp(-0.1)

callbacks = [
    # TensorBoard
    tf.keras.callbacks.TensorBoard(log_dir=logdir,       # where to write logs
                                   histogram_freq=2),    # histogram frequency (every 2 epochs)
    # save checkpoints
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,  # checkpoint path
                                       save_weights_only=True,    # save only weights and biases
                                       verbose=1,                 # print a message when saving
                                       period=1                   # save every epoch (deprecated; use save_freq)
                                       ),
    # stop training early: if val_loss has not improved for 3 epochs, stop, to limit overfitting
    tf.keras.callbacks.EarlyStopping(monitor='val_loss',  # quantity to monitor
                                     patience=3),          # patience in epochs
    # adjust the learning rate
    tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=0)
]

# Step 4: train the model
model.fit(x=x_train, y=y_train,
          epochs=1,
          validation_split=0.2,
          callbacks=callbacks)

# Step 5: evaluate the model
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))

# Step 6: also save the whole model
model.save('my_model.h5')  # creates an HDF5 file 'my_model.h5'

# Step 7: delete the model
del model

# Step 8: rebuild the model
model = create_model()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Step 9: restore the weights from the latest checkpoint
import os
checkpoint_dir = os.path.dirname(checkpoint_path)
latest = tf.train.latest_checkpoint(checkpoint_dir)  # find the most recent checkpoint
model.load_weights(latest)

# Step 10: evaluate the restored model
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of batches seen.
1500/1500 [==============================] - 9s 6ms/step - loss: 0.2449 - accuracy: 0.9279 - val_loss: 0.1252 - val_accuracy: 0.9632
Epoch 00001: saving model to ./checkpoint/min.01-0.13.ckpt
313/313 [==============================] - 1s 3ms/step - loss: 0.1215 - accuracy: 0.9617
train model, accuracy:96.17%
313/313 [==============================] - 1s 2ms/step - loss: 0.1215 - accuracy: 0.9617
Restored model, accuracy:96.17%
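To actually resume training (the "断点续训" in the section title) rather than just evaluate, call fit again on the restored model and tell Keras which epoch you are continuing from. A minimal sketch, assuming the checkpoint weights have just been loaded as above and using save_freq='epoch' in place of the deprecated period argument:

# A minimal sketch of resuming training after restoring the latest checkpoint.
# Assumes `model` already holds the restored weights (as above); epoch counts are illustrative.
resume_callbacks = [
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                       save_weights_only=True,
                                       save_freq='epoch',   # replaces the deprecated `period`
                                       verbose=1),
]
model.fit(x=x_train, y=y_train,
          epochs=5,              # train up to epoch 5 in total
          initial_epoch=1,       # epoch 1 was already finished before the interruption
          validation_split=0.2,
          callbacks=resume_callbacks)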

4. TensorFlow Playground
  • A web-based neural-network visualization tool. It is deliberately simple, but it helps build intuition about how neural networks work.
TensorFlow Playground link: https://playground.tensorflow.org/

5. TensorBoard
%load_ext tensorboard
%tensorboard --logdir logs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard
Reusing TensorBoard on port 6006 (pid 776), started 0:00:45 ago. (Use '!kill 776' to kill it.)
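The %tensorboard magic only works inside a notebook; from a terminal the equivalent is `tensorboard --logdir logs`. Besides the Keras TensorBoard callback used in section 3, you can also log your own scalars with tf.summary. A minimal sketch (the './logs/custom' directory and the 'my_metric' name are only illustrations):

# A minimal sketch: writing a custom scalar that shows up in TensorBoard's Scalars tab.
# './logs/custom' and 'my_metric' are illustrative names.
import tensorflow as tf

writer = tf.summary.create_file_writer('./logs/custom')
with writer.as_default():
    for step in range(100):
        tf.summary.scalar('my_metric', 1.0 / (step + 1), step=step)
writer.flush()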

