pytorch_hook机制
文章目录
-
- pytorch_hook机制
-
- 前言
- 一、hook简介
- 二、Tensor的hook机制
- 三、基于Module的hook机制
-
- register_forward_hook
- regitster_backward_hook
- 小结
- 示例代码
- 参考文献
前言
在理解hook机制之前,首先应该对pytorch张量的自动求导机制有所了解:PyTorch的自动求导机制详细解析,进而理解正向传播和反向传播的过程中程序在做什么。
在pytorch前向传播的过程中,会动态生成计算图;在反向传播过程中,对计算图中的每个模块的输入输出求解梯度,并把梯队回传到输出。在反向传播过程中为了减少内存消耗,会把过程中产生的梯度删除,仅保留计算图中叶子节点的梯度信息。但是,一些应用要求使用神经网络中间层输入输出的梯度值,如神经网络可解释性的CAM算法等。这时,使用hook机制就能帮助实现这个目标。
一、hook简介
hook机制主要非为两类:基于Tensor的hook机制,以及基于Module的hook机制。我的理解是,基于Tensor方便追踪某个特定张量的梯度,如某层的某个特定的输入;基于Module的hook机制则是在实际中应用比较多,可以帮助获得某层的输入输出的梯度,用于后续的计算。
无论使用哪种类型的hook机制,pytorch都要求我们注册一个hook,我的理解是,使用hook机制相当于一个钩子钩住了网络的前向传播或者反向传播,让用户可以在这中间添加一些操作(也就是调用一个能对前向、后向传播中间信息进行操作的函数)
二、Tensor的hook机制
这里使用一个简单的例子,参考代码见参考链接1。
使用**register_hook(hook)**注册一个钩子,也就是注册一个添加到计算图中间的函数。钩子函数使用的格式为:
hook(grad) -> Tensor or None
该例子使用一个简单的计算图进行计算。
x x x是随机生成一个3*1的tensor;
y = x + 3 y = x + 3 y=x+3
z = m e a n ( s u m ( y ) ) z=mean(sum(\sqrt{y})) z=mean(sum(y ?))
import torch
def print_grad(grad):
print('grad is \n',grad)x = torch.rand(3,1,requires_grad=True)
print('x value is \n',x)
y = x+3
print('y value is \n',y)
z = torch.mean(torch.pow(y, 1/2))
lr = 1e-3y.register_hook(print_grad)
z.backward() # 梯度求解
x.data -= lr*x.grad.data
print('new x is\n',x)
输出:
x value is
tensor([[0.5681],
[0.4868],
[0.9277]], requires_grad=True)
y value is
tensor([[3.5681],
[3.4868],
[3.9277]], grad_fn=)
grad is
tensor([[0.0882],
[0.0893],
[0.0841]])
new x is
tensor([[0.5680],
[0.4867],
[0.9276]], requires_grad=True)
在tensorx , y , z x,y,z x,y,z之间的函数关系定义好之后,计算图成功生成,中间变量的值被成功计算出来;在调用backward函数之前首先需要注册hook,否则hook就不会在backward的过程中被执行;最后,在反向传播过程中计算出来 y y y的梯度值并输出。
下面简单说明输出的结果:
输入为: x 1 , x 2 , x 3 x_1,x_2,x_3 x1?,x2?,x3?
第一次计算: y i = x i + 3 , i = 1 , 2 , 3 y_i=x_i+3,i=1,2,3 yi?=xi?+3,i=1,2,3,那么梯度为:KaTeX parse error: Undefined control sequence: \part at position 7: \frac{\?p?a?r?t?{y_i}}{\part{x_…
第二次计算: z = m e a n ( ∑ y i ) z=mean(\sum{\sqrt {y_i}}) z=mean(∑yi? ?),那么梯度为:KaTeX parse error: Undefined control sequence: \part at position 7: \frac{\?p?a?r?t?{z}}{\part{y_i}…
通过梯度传递原则,可以得到KaTeX parse error: Undefined control sequence: \part at position 7: \frac{\?p?a?r?t?{z}}{\part{x_i}…
借助上面的推导,可以很快自行验证上述结果的正确性。
三、基于Module的hook机制
与上面的基于tensor的hook机制原理相近,基于Module的hook机制是对模型的模块进行操作,比如神经网络的某个隐藏层。
有两种方法,用于前向传递的hook和用于后向传递的hook;
register_forward_hook 前向传递的hook主要用于在前向传播的过程中钩取模块之间的输入输出,使用**register_forward_hook(hook)**注册前向钩子,其中hook函数是一个如下形式的函数:
hook(module, input, output) -> None or modified output
regitster_backward_hook 后向传递的hook主要用于在后向传播的过程中钩取模块输入输出的梯度信息,使用**register_backward_hook(hook)**注册后向钩子,其中hook函数是一个如下形式的函数:
hook(module, grad_input, grad_output) -> Tensor or None
【pytorch学习记录|pytorch hook机制】如果使用register_backward_hook函数目前会报
"warning:在某些具有多个自动求导节点的场合,需要使用register_full_backward_hook获得完整的梯度信息(至于两者的区别,因为具体原理笔者还没有很理解,因此不多赘述。但是观察到如果对整个模型注册钩子,两个函数的输出结果不同)
Using a non-full backward hook when the forward contains multiple autograd Nodes"
小结 综上,我们只要能够理解register_forward_hook函数、register_backward_hook函数的不同用处即可,前向用于在前向传播过程中钩取模块输入输出值,后向用于在后向传播过程中钩取模块输入输出的梯度结果。
另外,事实上基于tensor的hook机制是类似于后向的hook机制的,hook函数是在后向传递的过程中被调用的。
示例代码 这里参考参考链接2中的代码,为了能够明显区分出前向和后向hook的区别进行了一定的修改,。
import torch
import torch.nn as nn
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")class MyMean(nn.Module):# 自定义除法module
def forward(self, input):
out = input/4
return outdef tensor_hook(grad):
print('tensor hook')
print('grad:', grad)
return gradclass MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.f1 = nn.Linear(4, 1, bias=True)
self.f2 = MyMean()
self.weight_init()def forward(self, input):
self.input = input
output = self.f1(input) # 先进行运算1,后进行运算2
output = self.f2(output)
return outputdef weight_init(self):
self.f1.weight.data.fill_(8.0) # 这里设置Linear的权重为8
self.f1.bias.data.fill_(2.0) # 这里设置Linear的bias为2def my_backward_hook(self, module, grad_input, grad_output):
print('doing my_backward_hook')
print('original grad:', grad_input)
print('original outgrad:', grad_output)
return grad_inputdef my_forward_hook(self, module, input, output):
print('doing my_forward_hook')
print('input:', input)
print('output', output)if __name__ == '__main__':
input = torch.tensor([1, 2, 3, 4], dtype=torch.float32, requires_grad=True).to(device)net = MyNet()
net.to(device)
net.f1.register_forward_hook(net.my_forward_hook) # 这两个hook函数一定要result = net(input)执行前执行,因为hook函数实在forward的时候进行绑定的
net.f2.register_forward_hook(net.my_forward_hook)net.f1.register_full_backward_hook(net.my_backward_hook) # 这两个hook函数一定要result = net(input)执行前执行,因为hook函数实在forward的时候进行绑定的
net.f2.register_full_backward_hook(net.my_backward_hook)
input.register_hook(tensor_hook)print('forward now')
result = net(input)
print('result =', result)
print('over forward')print('\nbackward now')
result.backward()
print('over backward')print('\ninput.grad:', input.grad)
for param in net.parameters():
print('{}:grad->{}'.format(param, param.grad))
该网络定义了简单的单层全连接网络,为了方便理解,采用了固定的参数,因此可以很方便地证明代码的运行结果,这里不再赘述。
得到的输出是:
forward now
doing my_forward_hook
input: (tensor([1., 2., 3., 4.], grad_fn=),)
output tensor([82.], grad_fn=)
doing my_forward_hook
input: (tensor([82.], grad_fn=),)
output tensor([20.5000], grad_fn=)
result = tensor([20.5000], grad_fn=)
over forwardbackward now
doing my_backward_hook
original grad: (tensor([0.2500]),)
original outgrad: (tensor([1.]),)
doing my_backward_hook
original grad: (tensor([2., 2., 2., 2.]),)
original outgrad: (tensor([0.2500]),)
tensor hook
grad: tensor([2., 2., 2., 2.])
over backwardinput.grad: tensor([2., 2., 2., 2.])
Parameter containing:
tensor([[8., 8., 8., 8.]], requires_grad=True):grad->tensor([[0.2500, 0.5000, 0.7500, 1.0000]])
Parameter containing:
tensor([2.], requires_grad=True):grad->tensor([0.2500])
对于结果值得注意的我认为有两点:
- 注意hook函数调用的位置,两者分别在forward和backward的过程中调用;
- 注册钩子需要在模型前向传播之前,因为hook是在前向传播的过程中链接上hook的。
参考文献
参考链接1:register_farward_hook
参考链接2:register_backward_hook
推荐阅读
- 计算机视觉|OpenCV之图像轮廓(绘制图像轮廓)
- python|Real-Time High-Resolution Background Matting翻译
- python|深度学习理论基础
- 计算机视觉|YoloV5建立自己的数据集并进行训练
- Python|【Python百日基础系列】Day03 - Python 数据类型
- spring|Java最全面试题之Spring篇
- 本书适合Python 程序员、数据分析人员《Python机器学习实践指南》(好书分享更新中)
- 分享|python基础-零基础入门到精通
- 深度学习|【延伸阅读】让老照片重现光彩(五)(Pix2PixHD模型源代码+中文注释)