TensorFlow High-Level Wrapper Estimator: DNNClassifier
Straight to the code:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
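# Written for TensorFlow 1.x: tf.logging, tf.train.AdamOptimizer, tf.estimator.inputs
# and tensorflow.examples.tutorials are not available in TensorFlow 2.

# Print log messages to the screen
tf.logging.set_verbosity(tf.logging.INFO)

# one_hot=False keeps the labels as integer class ids, which DNNClassifier expects
mnist = input_data.read_data_sets("../../datasets/MNIST_data", one_hot=False)

# Define the model's input. All specified feature columns are concatenated together.
feature_columns = [tf.feature_column.numeric_column("image", shape=[784])]

# Define the model with DNNClassifier. DNNClassifier can only build multi-layer
# fully connected networks; hidden_units gives the number of nodes in each hidden layer.
estimator = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                       hidden_units=[500],
                                       n_classes=10,                        # number of output-layer nodes
                                       optimizer=tf.train.AdamOptimizer(),  # optimizer
                                       model_dir="log")                     # directory for checkpoints and logs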
Extracting ../../datasets/MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../datasets/MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../datasets/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_task_type': 'worker', '_is_chief': True, '_cluster_spec': , '_save_checkpoints_steps': None, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_num_ps_replicas': 0, '_tf_random_seed': None, '_master': '', '_num_worker_replicas': 1, '_task_id': 0, '_log_step_count_steps': 100, '_model_dir': 'log', '_save_summary_steps': 100}
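With hidden_units=[500] and n_classes=10, the model above is a single fully connected hidden layer (ReLU, DNNClassifier's default activation) between the 784 input pixels and the 10 output logits. As a rough sanity check on its size, here is a minimal pure-Python sketch that counts weights and biases for such a stack of dense layers. The helper `count_dense_params` is hypothetical, not a TensorFlow API, and assumes every layer is a plain dense layer with a bias term.

```python
def count_dense_params(n_inputs, hidden_units, n_outputs):
    """Count weights + biases in a stack of fully connected layers.

    Hypothetical helper: assumes each layer is dense with a bias vector,
    mirroring the n_inputs -> hidden_units -> n_outputs layout of DNNClassifier.
    """
    sizes = [n_inputs] + list(hidden_units) + [n_outputs]
    total = 0
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        total += fan_in * fan_out + fan_out  # weight matrix + bias vector
    return total


# 784 pixels -> 500 hidden units -> 10 class logits
print(count_dense_params(784, [500], 10))  # 397510
```

Almost all of these parameters (392,000 weights) sit in the first layer, which is why widening hidden_units grows the model much faster than adding classes.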
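train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"image": mnist.train.images},        # all input data; the key must match the feature column name
    y=mnist.train.labels.astype(np.int32),  # the correct labels for x
    num_epochs=None,                        # passes over the data; None cycles indefinitely, so `steps` sets the length
    batch_size=128,                         # size of each batch
    shuffle=True)                           # whether to shuffle the data randomly

# Train the model. The model defined by DNNClassifier uses cross-entropy as its loss function by default.
estimator.train(input_fn=train_input_fn, steps=10000)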
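Because num_epochs=None makes the input function cycle through the training data indefinitely, each training step simply consumes one batch of batch_size examples until steps is reached. Assuming the 55,000-image training split that input_data.read_data_sets produces by default (it holds out 5,000 images for validation), a quick sketch of how many passes over the data steps=10000 amounts to; `epochs_covered` is a hypothetical helper for illustration only.

```python
def epochs_covered(steps, batch_size, n_train):
    """Return how many passes over the training set `steps` batches amount to."""
    examples_seen = steps * batch_size
    return examples_seen / n_train


# steps=10000, batch_size=128, 55,000 training images (read_data_sets' default split)
print(round(epochs_covered(10000, 128, 55000), 2))  # 23.27
```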
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into log/model.ckpt.
INFO:tensorflow:loss = 304.06, step = 1
INFO:tensorflow:global_step/sec: 37.6473
INFO:tensorflow:loss = 22.9116, step = 101 (2.650 sec)
INFO:tensorflow:global_step/sec: 36.5129
INFO:tensorflow:loss = 16.9818, step = 201 (2.738 sec)
INFO:tensorflow:global_step/sec: 37.3001
INFO:tensorflow:loss = 18.5241, step = 301 (2.681 sec)
INFO:tensorflow:global_step/sec: 37.91
INFO:tensorflow:loss = 22.4564, step = 401 (2.637 sec)
INFO:tensorflow:global_step/sec: 36.15
INFO:tensorflow:loss = 24.7065, step = 501 (2.767 sec)
INFO:tensorflow:global_step/sec: 35.9403
INFO:tensorflow:loss = 15.0627, step = 601 (2.782 sec)
INFO:tensorflow:global_step/sec: 37.5324
INFO:tensorflow:loss = 11.4905, step = 701 (2.664 sec)
INFO:tensorflow:global_step/sec: 37.3027
INFO:tensorflow:loss = 6.44032, step = 801 (2.682 sec)
INFO:tensorflow:global_step/sec: 36.9734
INFO:tensorflow:loss = 15.6424, step = 901 (2.704 sec)
INFO:tensorflow:global_step/sec: 36.7835
INFO:tensorflow:loss = 17.7588, step = 1001 (2.718 sec)
INFO:tensorflow:global_step/sec: 36.1868
INFO:tensorflow:loss = 5.8567, step = 1101 (2.764 sec)
INFO:tensorflow:global_step/sec: 37.21
INFO:tensorflow:loss = 5.6694, step = 1201 (2.688 sec)
INFO:tensorflow:global_step/sec: 36.9753
INFO:tensorflow:loss = 5.92369, step = 1301 (2.704 sec)
INFO:tensorflow:global_step/sec: 37.1743
INFO:tensorflow:loss = 4.01127, step = 1401 (2.690 sec)
INFO:tensorflow:global_step/sec: 37.0517
INFO:tensorflow:loss = 3.93887, step = 1501 (2.699 sec)
INFO:tensorflow:global_step/sec: 35.0115
INFO:tensorflow:loss = 17.9374, step = 1601 (2.859 sec)
INFO:tensorflow:global_step/sec: 36.3979
INFO:tensorflow:loss = 12.8719, step = 1701 (2.745 sec)
INFO:tensorflow:global_step/sec: 37.053
INFO:tensorflow:loss = 4.94476, step = 1801 (2.698 sec)
INFO:tensorflow:global_step/sec: 37.2657
INFO:tensorflow:loss = 8.36541, step = 1901 (2.684 sec)
INFO:tensorflow:global_step/sec: 36.1325
INFO:tensorflow:loss = 2.80073, step = 2001 (2.767 sec)
INFO:tensorflow:global_step/sec: 35.9877
INFO:tensorflow:loss = 2.12194, step = 2101 (2.779 sec)
INFO:tensorflow:global_step/sec: 35.4156
INFO:tensorflow:loss = 5.8448, step = 2201 (2.825 sec)
INFO:tensorflow:global_step/sec: 36.4246
INFO:tensorflow:loss = 3.41064, step = 2301 (2.745 sec)
INFO:tensorflow:global_step/sec: 37.0978
INFO:tensorflow:loss = 5.94868, step = 2401 (2.696 sec)
INFO:tensorflow:global_step/sec: 36.8387
INFO:tensorflow:loss = 3.89859, step = 2501 (2.714 sec)
INFO:tensorflow:global_step/sec: 37.4655
INFO:tensorflow:loss = 1.92916, step = 2601 (2.669 sec)
INFO:tensorflow:global_step/sec: 36.8676
INFO:tensorflow:loss = 1.62655, step = 2701 (2.715 sec)
INFO:tensorflow:global_step/sec: 37.1859
INFO:tensorflow:loss = 1.53982, step = 2801 (2.687 sec)
INFO:tensorflow:global_step/sec: 37.2527
INFO:tensorflow:loss = 1.03436, step = 2901 (2.685 sec)
INFO:tensorflow:global_step/sec: 36.3516
INFO:tensorflow:loss = 1.27746, step = 3001 (2.750 sec)
INFO:tensorflow:global_step/sec: 36.8335
INFO:tensorflow:loss = 1.13165, step = 3101 (2.715 sec)
INFO:tensorflow:global_step/sec: 37.1274
INFO:tensorflow:loss = 1.94848, step = 3201 (2.694 sec)
INFO:tensorflow:global_step/sec: 37.2034
INFO:tensorflow:loss = 2.70218, step = 3301 (2.687 sec)
INFO:tensorflow:global_step/sec: 36.6637
INFO:tensorflow:loss = 1.92926, step = 3401 (2.728 sec)
INFO:tensorflow:global_step/sec: 37.3145
INFO:tensorflow:loss = 0.442604, step = 3501 (2.679 sec)
INFO:tensorflow:global_step/sec: 36.5398
INFO:tensorflow:loss = 0.900882, step = 3601 (2.737 sec)
INFO:tensorflow:global_step/sec: 37.3834
INFO:tensorflow:loss = 4.18892, step = 3701 (2.675 sec)
INFO:tensorflow:global_step/sec: 39.2497
INFO:tensorflow:loss = 1.48036, step = 3801 (2.552 sec)
INFO:tensorflow:global_step/sec: 38.6457
INFO:tensorflow:loss = 0.884206, step = 3901 (2.583 sec)
INFO:tensorflow:global_step/sec: 38.2075
INFO:tensorflow:loss = 0.948791, step = 4001 (2.617 sec)
INFO:tensorflow:global_step/sec: 38.4911
INFO:tensorflow:loss = 0.194128, step = 4101 (2.598 sec)
INFO:tensorflow:global_step/sec: 39.1245
INFO:tensorflow:loss = 0.927972, step = 4201 (2.556 sec)
INFO:tensorflow:global_step/sec: 37.9535
INFO:tensorflow:loss = 0.588909, step = 4301 (2.635 sec)
INFO:tensorflow:global_step/sec: 37.7356
INFO:tensorflow:loss = 0.640977, step = 4401 (2.648 sec)
INFO:tensorflow:global_step/sec: 38.9475
INFO:tensorflow:loss = 1.55033, step = 4501 (2.569 sec)
INFO:tensorflow:global_step/sec: 38.4368
INFO:tensorflow:loss = 0.727761, step = 4601 (2.602 sec)
INFO:tensorflow:global_step/sec: 38.6664
INFO:tensorflow:loss = 0.875689, step = 4701 (2.586 sec)
INFO:tensorflow:global_step/sec: 38.4306
INFO:tensorflow:loss = 0.840781, step = 4801 (2.602 sec)
INFO:tensorflow:global_step/sec: 38.5601
INFO:tensorflow:loss = 0.192555, step = 4901 (2.593 sec)
INFO:tensorflow:global_step/sec: 37.9238
INFO:tensorflow:loss = 0.235096, step = 5001 (2.637 sec)
INFO:tensorflow:global_step/sec: 38.1754
INFO:tensorflow:loss = 1.71132, step = 5101 (2.621 sec)
INFO:tensorflow:global_step/sec: 37.7428
INFO:tensorflow:loss = 0.565316, step = 5201 (2.648 sec)
INFO:tensorflow:global_step/sec: 38.406
INFO:tensorflow:loss = 0.435648, step = 5301 (2.604 sec)
INFO:tensorflow:global_step/sec: 39.0185
INFO:tensorflow:loss = 0.427775, step = 5401 (2.563 sec)
INFO:tensorflow:global_step/sec: 38.5341
INFO:tensorflow:loss = 0.277333, step = 5501 (2.594 sec)
INFO:tensorflow:global_step/sec: 38.068
INFO:tensorflow:loss = 0.222045, step = 5601 (2.629 sec)
INFO:tensorflow:global_step/sec: 32.132
INFO:tensorflow:loss = 0.717366, step = 5701 (3.113 sec)
INFO:tensorflow:global_step/sec: 31.4068
INFO:tensorflow:loss = 1.14066, step = 5801 (3.181 sec)
INFO:tensorflow:global_step/sec: 36.2213
INFO:tensorflow:loss = 0.0925966, step = 5901 (2.761 sec)
INFO:tensorflow:global_step/sec: 35.4066
INFO:tensorflow:loss = 0.334483, step = 6001 (2.826 sec)
INFO:tensorflow:global_step/sec: 38.1624
INFO:tensorflow:loss = 2.09814, step = 6101 (2.619 sec)
INFO:tensorflow:global_step/sec: 36.8427
INFO:tensorflow:loss = 0.0559563, step = 6201 (2.723 sec)
INFO:tensorflow:global_step/sec: 37.4233
INFO:tensorflow:loss = 0.173522, step = 6301 (2.663 sec)
INFO:tensorflow:global_step/sec: 37.2928
INFO:tensorflow:loss = 1.80461, step = 6401 (2.683 sec)
INFO:tensorflow:global_step/sec: 37.4197
INFO:tensorflow:loss = 0.500693, step = 6501 (2.673 sec)
INFO:tensorflow:global_step/sec: 37.2329
INFO:tensorflow:loss = 0.564081, step = 6601 (2.684 sec)
INFO:tensorflow:global_step/sec: 35.6874
INFO:tensorflow:loss = 2.12218, step = 6701 (2.802 sec)
INFO:tensorflow:global_step/sec: 37.8897
INFO:tensorflow:loss = 0.913451, step = 6801 (2.639 sec)
INFO:tensorflow:global_step/sec: 36.9869
INFO:tensorflow:loss = 0.121778, step = 6901 (2.703 sec)
INFO:tensorflow:global_step/sec: 37.2794
INFO:tensorflow:loss = 1.24115, step = 7001 (2.684 sec)
INFO:tensorflow:global_step/sec: 36.5617
INFO:tensorflow:loss = 2.33691, step = 7101 (2.734 sec)
INFO:tensorflow:global_step/sec: 37.2914
INFO:tensorflow:loss = 0.194615, step = 7201 (2.682 sec)
INFO:tensorflow:global_step/sec: 37.4636
INFO:tensorflow:loss = 1.01934, step = 7301 (2.669 sec)
INFO:tensorflow:global_step/sec: 36.3987
INFO:tensorflow:loss = 0.0396438, step = 7401 (2.748 sec)
INFO:tensorflow:global_step/sec: 37.0134
INFO:tensorflow:loss = 0.318336, step = 7501 (2.701 sec)
INFO:tensorflow:global_step/sec: 37.4563
INFO:tensorflow:loss = 0.295077, step = 7601 (2.670 sec)
INFO:tensorflow:global_step/sec: 36.1086
INFO:tensorflow:loss = 0.300212, step = 7701 (2.770 sec)
INFO:tensorflow:global_step/sec: 37.1408
INFO:tensorflow:loss = 0.0957186, step = 7801 (2.691 sec)
INFO:tensorflow:global_step/sec: 37.3442
INFO:tensorflow:loss = 0.207393, step = 7901 (2.678 sec)
INFO:tensorflow:global_step/sec: 36.8388
INFO:tensorflow:loss = 0.702379, step = 8001 (2.714 sec)
INFO:tensorflow:global_step/sec: 37.0512
INFO:tensorflow:loss = 0.118717, step = 8101 (2.699 sec)
INFO:tensorflow:global_step/sec: 37.1012
INFO:tensorflow:loss = 0.414811, step = 8201 (2.698 sec)
INFO:tensorflow:global_step/sec: 37.0407
INFO:tensorflow:loss = 0.0207637, step = 8301 (2.698 sec)
INFO:tensorflow:global_step/sec: 37.4401
INFO:tensorflow:loss = 0.915473, step = 8401 (2.671 sec)
INFO:tensorflow:global_step/sec: 35.7477
INFO:tensorflow:loss = 0.0145219, step = 8501 (2.797 sec)
INFO:tensorflow:global_step/sec: 36.8545
INFO:tensorflow:loss = 0.19546, step = 8601 (2.713 sec)
INFO:tensorflow:global_step/sec: 36.4362
INFO:tensorflow:loss = 0.162718, step = 8701 (2.745 sec)
INFO:tensorflow:global_step/sec: 38.0979
INFO:tensorflow:loss = 0.110679, step = 8801 (2.625 sec)
INFO:tensorflow:global_step/sec: 37.3399
INFO:tensorflow:loss = 0.364681, step = 8901 (2.679 sec)
INFO:tensorflow:global_step/sec: 34.5091
INFO:tensorflow:loss = 0.11277, step = 9001 (2.897 sec)
INFO:tensorflow:global_step/sec: 36.2612
INFO:tensorflow:loss = 0.0492335, step = 9101 (2.758 sec)
INFO:tensorflow:global_step/sec: 36.4491
INFO:tensorflow:loss = 1.55669, step = 9201 (2.743 sec)
INFO:tensorflow:global_step/sec: 37.5527
INFO:tensorflow:loss = 3.54684, step = 9301 (2.663 sec)
INFO:tensorflow:global_step/sec: 37.1559
INFO:tensorflow:loss = 0.794145, step = 9401 (2.691 sec)
INFO:tensorflow:global_step/sec: 37.3445
INFO:tensorflow:loss = 0.211793, step = 9501 (2.679 sec)
INFO:tensorflow:global_step/sec: 36.5885
INFO:tensorflow:loss = 0.34873, step = 9601 (2.734 sec)
INFO:tensorflow:global_step/sec: 36.6112
INFO:tensorflow:loss = 0.126815, step = 9701 (2.730 sec)
INFO:tensorflow:global_step/sec: 37.0562
INFO:tensorflow:loss = 0.021787, step = 9801 (2.699 sec)
INFO:tensorflow:global_step/sec: 37.4349
INFO:tensorflow:loss = 0.10613, step = 9901 (2.671 sec)
INFO:tensorflow:Saving checkpoints for 10000 into log/model.ckpt.
INFO:tensorflow:Loss for final step: 0.0945373.
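The first logged loss, about 304 at step 1, is close to what a freshly initialized 10-class classifier gives if the per-example cross-entropy is summed over the 128-image batch rather than averaged (the log is consistent with this, and as far as we know summing is the TF 1.x default for canned estimators). With near-uniform predictions, each example contributes about ln 10. A minimal pure-Python sketch of softmax cross-entropy; `softmax` and `cross_entropy` are hypothetical helpers, not TensorFlow APIs.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    """Softmax cross-entropy for one example with an integer class label."""
    return -math.log(softmax(logits)[label])


# Assumption: an untrained network outputs roughly equal logits for all 10 classes.
uniform_logits = [0.0] * 10
per_example = cross_entropy(uniform_logits, label=3)
print(round(per_example, 4))        # 2.3026 (= ln 10)
print(round(128 * per_example, 2))  # 294.73, summed over a 128-image batch
```

Random initial weights only make the logits approximately equal, so the observed 304.06 lands a little above this idealized value.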
# Input for testing; essentially the same as for training, except that the data
# is read exactly once (num_epochs=1) and is not shuffled
test_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"image": mnist.test.images},
    y=mnist.test.labels.astype(np.int32),
    num_epochs=1,
    batch_size=128,
    shuffle=False)

test_results = estimator.evaluate(input_fn=test_input_fn)
accuracy_score = test_results["accuracy"]
print("\nTest accuracy: %g %%" % (accuracy_score * 100))
print(test_results)
INFO:tensorflow:Starting evaluation at 2017-12-28-18:43:13
INFO:tensorflow:Restoring parameters from log/model.ckpt-10000
INFO:tensorflow:Finished evaluation at 2017-12-28-18:43:15
INFO:tensorflow:Saving dict for global step 10000: accuracy = 0.9818, average_loss = 0.0919714, global_step = 10000, loss = 11.642

Test accuracy: 98.18 %
{'average_loss': 0.09197142, 'accuracy': 0.98180002, 'global_step': 10000, 'loss': 11.641952}
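In these results, average_loss is the mean cross-entropy per test image, while loss appears to be the batch-summed loss averaged over the evaluation batches. With num_epochs=1 and batch_size=128, the 10,000 MNIST test images are fed as 78 full batches plus a final batch of 16, i.e. 79 batches, and the logged numbers line up with that reading. A minimal sketch that re-derives them; `eval_summary` is a hypothetical helper, and the correct-prediction count 9818 is simply 0.9818 × 10,000 from the log.

```python
import math


def eval_summary(n_examples, batch_size, average_loss, n_correct):
    """Re-derive evaluation metrics from per-example totals.

    Assumption: each batch's loss is the *sum* over its examples, and the
    reported `loss` is the mean of those per-batch sums.
    """
    n_batches = math.ceil(n_examples / batch_size)  # final batch may be smaller
    total_loss = average_loss * n_examples          # sum over all test examples
    return {
        "batches": n_batches,
        "loss": total_loss / n_batches,             # mean of per-batch summed losses
        "accuracy": n_correct / n_examples,         # fraction classified correctly
    }


# Values from the evaluation log above
print(eval_summary(10000, 128, 0.0919714, 9818))
```

This gives 79 batches, a loss of about 11.642 (matching the logged 11.641952 up to float32 rounding) and an accuracy of 0.9818.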