linux|tensorflow源码分析(七)-优化函数
Tensorflow系统中的优化函数主要根据传入的损失函数的梯度计算出损失函数的极值,在计算过程中会根据传入的学习率不断的修改模型中的参数,从而使神经模型在训练数据上的损失函数尽可能小,从而得到一个质量比较好的模型。
Tensorflow中实现了很多的优化函数:GradientDescentOptimizerAdagradDAOptimizer AdamOptimizerAdagradOptimizer 等,该文档主要介绍adamoptimizer优化函数,通过该接口了解优化函数的实现和工作机制。
在tensorflow中实现了很多种梯度优化的方法,包括自适应梯度优化算法(adamoptimizer),随机梯度下降算法SGD和简单梯度下降算法GD。各种优化器的差别在于随着训练过程学习率的调整策略不同。
自适应算法和sgd算法比较:在测试集上SGD算法得到的效果普遍要好于自适应算法,尽管有时候在训练集上自适应算法有时候会得到更小的loss;所以很多情况下依然会选择随机梯度下降算法来计算最终的模型。
Adamoptimizer初始化:
文章图片
此函数是Adam优化算法:是一个寻找全局最优点的优化算法,引入了二次方梯度校正。Adam 算法根据损失函数对每个参数的梯度的一阶矩估计和二阶矩估计动态调整针对于每个参数的学习速率。
参数解析:
learning_rate:控制了权重的更新比率(如 0.001)。较大的值(如 0.3)在学习率更新前会有更快的初始学习,而较小的值(如 1.0E-5)会令训练收敛到更好的性能。
beta1:一阶矩估计的指数衰减率,一阶矩是样本的平均值
beta2:二阶矩估计的指数衰减率,二阶矩是样本平方的平均值
epsilon:该参数是非常小的数,其为了防止在实现中除以零
use_locking:如果设置为true,在更新参数时使用锁操作
name:已经创建的adamoptimizer operation的名称
AdamOptimizer继承自Optimizer类,optimizer是基本的梯度优化类,该类不经常使用,使用最多的是该类的子类,AdamOptimizer继承了父类Optimizer的minimize接口,该接口添加操作节点,用于最小化loss,该函数是简单的合并了compute_gradients()与apply_gradients()函数返回为一个优化更新后的var_list;compute_gradients()主要实现计算loss的梯度,apply_gradients()主要是把计算出的梯度应用到变量上,实现变量的更新。有的时候,为了特殊目的,比如作梯度的修改,也可以调用以上两部,并把自己的对梯度的操作放在compute_gradients 与 apply_gradients 之间。
compute_gradients()和apply_gradients()的实现可以详细展开。
文章图片
参数解析:
Loss:是一个tensor,在该tensor中包含了需要最小化的值
Var_list:可以是一个list也可以是tuple元组,元素的类型是tf.variable,为了得到最小化的loss可以不断的更新该list或tuple内的变量,如果不指定默认是graph中graphKeys.TRAINABLE_VARIABLES变量收集器中的变量
Gate_gradients:如何控制梯度的计算,可以是`GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`
Aggregation_method: 指定用于组合渐变项的方法,可用的值定义在AggregationMethod类中
Collocate_gradients_with_ops: 如果为True,尝试将渐变与相应的op进行对齐
Grad_loss:可选项,保存为loss计算出来的梯度
返回值:
【linux|tensorflow源码分析(七)-优化函数】返回一个元素是(gradient, variable)的list,variable是指摸个变量,gradient是该变量相对于loss的梯度;其中gradient是可以为None。
Compute_gradients主要通过调用gradients.gradients()接口实现梯度的计算,实现的过程就是求导。
文章图片
参数解析:
Grads_and_vars:该参数是compute_gradients()接口的返回值
Global_step:可选项,该参数可以记录variable更新的次数,每更新一次该值会+1
Name:可选项,可以指定apply_gradients()返回的operation的名字
返回值:
该接口返回可以应用gradients梯度的operation
Apply_gradients主要通过调用_create_slots(), _prepare(), _apply_dense(), and _apply_sparse()
推荐阅读
- Linux下面如何查看tomcat已经使用多少线程
- Beego打包部署到Linux
- Android事件传递源码分析
- Quartz|Quartz 源码解析(四) —— QuartzScheduler和Listener事件监听
- [源码解析]|[源码解析] NVIDIA HugeCTR,GPU版本参数服务器---(3)
- ffmpeg源码分析01(结构体)
- Linux|109 个实用 shell 脚本
- Java程序员阅读源码的小技巧,原来大牛都是这样读的,赶紧看看!
- linux定时任务contab
- 芯灵思SinlinxA33开发板Linux内核定时器编程