常见网络层类 对于常见的神经网络层的实现方式:
- 使用张量方式 利用 底层接口函数 实现。接口函数一般放在tf.nn模块中。
- 直接用层的方式搭建。tf.kearas.layers中提供了大量常见神经网络的类:
全连接层、激活函数层、池化层、卷积层、循环神经网络层。
调用__call__函数完成前向计算(前向计算的逻辑自动保存于__call__中)。
例子:
两种方式使用softmax。
# 1.手动实现softmax计算
x = tf.constant([1, 2, 3], dtype=tf.float32)
print(tf.nn.softmax(x))# softmax函数输入tensor的dtype只能是float
# 2.将softmax看作一个激活层,先定义层类,调用__call__前向传播
soft_layer = layers.Softmax()
print(soft_layer(x))
tf.Tensor([0.09003057 0.24472848 0.66524094], shape=(3,), dtype=float32)
tf.Tensor([0.09003057 0.24472848 0.66524094], shape=(3,), dtype=float32)
网络容器Sequential 前面说直接调用层类对象的__call__就可以实现前向传播,但是如果层多的话就很麻烦,一层运算要写一行代码。
Sequential 容器可以将很多层封起来,组成一个大的网络类。这样就一行代码就可以实现前向传播。
Sequential 容器还可以动态添加网络层。使用add()方法即可。
创建网络层类对象时,其实并没有马上就创建 内部的权值变量【而且刚创建的时候层对象也不知道输入自变量的维度,也没办法初始化参数】。有两种方法可以让其开始初始化参数:
- 前向传播一次。
- 调用build方法指定输入的自变量大小。
当我们通过 Sequential 容量封装多个网络层时,每层的参数列表将会自动并入Sequential 容器的参数列表中,不需要人为合并网络参数列表,这也是 Sequential 容器的便捷之处。
trainable_variables查看优化变量。
★想要快速搭建神经网络,就多使用Sequential容器。
举例:
# 用Sequential创建网络模块
network = Sequential([layers.Dense(12),
layers.ReLU()])
network.add(layers.Dense(6, activation=tf.nn.relu))
network.add(layers.Dense(3))
network.add(layers.Softmax())# 注意:上面的过程,各层并没有初始化优化参数。所以也不能调用summary
# 先build()初始化参数
network.build(input_shape=(2, 12))
print(network.summary())# 查看层信息
print(network.trainable_variables)# 查看待优化参数
Model: "sequential"
_________________________________________________________________
Layer (type)Output ShapeParam #
=================================================================
dense (Dense)(2, 12)156
_________________________________________________________________
re_lu (ReLU)(2, 12)0
_________________________________________________________________
dense_1 (Dense)(2, 6)78
_________________________________________________________________
dense_2 (Dense)(2, 3)21
_________________________________________________________________
softmax_1 (Softmax)(2, 3)0
=================================================================
Total params: 255
Trainable params: 255
Non-trainable params: 0
_________________________________________________________________
None
[, , , , , ]Process finished with exit code 0
模型装配、训练与测试 神经网络优化/训练 的 理论流程:
搭建完网络框架以后,进行前向传播计算,其中前向传播的最后一步一般是损失函数的计算。
前向传播的过程中会记录计算图。
使用反向传播BP算法,利用前向传播时创建的计算图,反向路径计算所有待优化参数的梯度。
得到梯度以后,选择优化算法,对参数进行 优化。
【因为优化算法绝大部分是梯度相关的算法,因此BP算法计算梯度非常重要,是深度学习的基石。】
训练的一般流程:
将样本数据分成多个batch进行循环。
每一个batch时,先前向传播计算损失,然后反向传播计算梯度。【损失函数也要自己事先确定好】
得到当前梯度后,进行一步梯度迭代。【具体的梯度迭代形式取决于用什么优化算法】
…
直至达到规定的精度 or 最大迭代次数。
如果手动实现上述流程,代码量不少,也要考虑很多细节。但是,
Keras提供了compile()和fit()高层函数来帮我们实现上述流程,
只需两行代码,轻松搞定。
装配:compile()函数先指定 优化器对象、损失函数类型、评价指标
模型训练:fit()函数将训练集 和 验证集 数据送入。
history = network.fit(train_db, epochs=5, validation_data=https://www.it610.com/article/val_db, validation_freq=2)
train_db 为 tf.data.Dataset 对象,也可以传入 Numpy Array 类型的数据
epochs 参数指定训练迭代的 Epoch 数量
validation_data 参数指定用于验证(测试)的数据集和验证的频率validation_freq
其中 history.history 为字典对象,包含了训练过程中的 loss、测量指标等记录项,我们可以直接查看这些训练数据
{'accuracy': [0.00011666667, 0.0, 0.0, 0.010666667, 0.02495],
'loss': [2465719710540.5845, # 历史训练误差
78167808898516.03,
404488834518159.6,
1049151145155144.4,
1969370184858451.0],
'val_accuracy': [0.0, 0.0], # 历史验证准确率
# 历史验证误差
'val_loss': [197178788071657.3, 1506234836955706.2]}
- 可以看到通过 compile&fit 方 式实现的代码非常简洁和高效,大大缩减了开发时间。但是因为接口非常高层,灵活性也降低了,是否使用需要用户自行判断。
数据加载进入内存后,需要转换成 Dataset 对象,才能利用 TensorFlow 提供的各种便捷功能。
train_db = train_db.shuffle(10000) # 随机打散样本,不会打乱样本与标签映射关系
train_db = train_db.batch(128) # 设置批训练,batch size 为 128
y = tf.one_hot(y, depth=10) # one-hot 编码
- 对于 Dataset 对象,进行一个epoch训练
for step, (x,y) in enumerate(train_db): # 迭代数据集对象,带 step 参数
或
for x,y in train_db: # 迭代数据集对象
- 设置epoch数:
for epoch in range(20): # 训练 Epoch 数
for step, (x,y) in enumerate(train_db): # 迭代 Step 数
# training...
【tensorflow学习|tensorflo之keras高层接口】注:batch这个属性是嵌入在train_db里的。
推荐阅读
- Python机器学习基础教程|Python机器学习日记7(朴素贝叶斯分类器(持续更新))
- 数据库|md5解密
- Opencv项目实战|Opencv项目实战(07 人脸识别和考勤系统)
- 可视化|偶然发现的Python自学宝藏地带!
- Python|Python分析淘宝月饼销售数据,五仁月饼王者地位不可动摇!
- 程序员|爬取某宝4000条数据,用Python做了一个 “月饼“ 可视化大屏,过中秋
- python基础技能|python入门之时间处理日期库
- 数据分析|Python分析淘宝月饼销售数据,五仁月饼王者地位不可动摇
- python|中秋味的可视化大屏 【以python pyecharts为工具】