深度学习|一文理解深度学习框架中的InstanceNorm

深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

撰文|梁德澎

本文首发于公众号GiantPandaCV
本文主要推导 InstanceNorm 关于输入和参数的梯度公式,同时还会结合 PyTorch 和 MXNet 里的 InstanceNorm 代码来分析。

1 InstanceNorm 与 BatchNorm 的联系 对一个形状为 (N, C, H, W) 的张量应用 InstanceNorm[4] 操作,其实等价于先把该张量 reshape 为 (1, N * C, H, W)的张量,然后应用 BatchNorm[5] 操作。而 gamma 和 beta 参数的每个通道所对应输入张量的位置都是一致的。
而 InstanceNorm 与 BatchNorm 不同的地方在于:

  • InstanceNorm 训练与预测阶段行为一致,都是利用当前 batch 的均值和方差计算
  • BatchNorm 训练阶段利用当前 batch 的均值和方差,测试阶段则利用训练阶段通过移动平均统计的均值和方差
论文[6]中的一张示意图,就很好地解释了两者的联系:
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

https://arxiv.org/pdf/1803.08494.pdf 所以 InstanceNorm 对于输入梯度和参数求导过程与 BatchNorm 类似,下面开始进入正题。 2 梯度推导过程详解 在开始推导梯度公式之前,首先约定输入,参数,输出等符号:
  • 输入张量 , 形状为(N, C, H, W),rehape 为 (1, N * C, M) 其中 M=H*W
  • 参数 ,形状为 (1, C, 1, 1),每个通道值对应 N*M 个输入,在计算的时候首先通过在第0维 repeat N 次再 reshape 成 (1, N*C, 1, 1)
  • 参数 ,形状为 (1, C, 1, 1),每个通道值对应 N*M 个输入,在计算的时候首先通过在第0维 repeat N 次再 reshape 成 (1, N*C, 1, 1)
而输入张量 reshape 成 (1, N * C, M)之后,每个通道上是一个长度为 M 的向量,这些向量之间的计算是不像干的,每个向量计算自己的 normalize 结果。所以求导也是各自独立。因此下面的均值、方差符号约定和求导也只关注于其中一个向量,其他通道上的向量计算都是一样的。
  • 一个向量上的均值
  • 一个向量上的方差
  • 一个向量上一个点的 normalize 中间输出
  • 一个向量上一个点的 normalize 最终输出 ,其中和表示这个向量所对应的 gamma 和 beta 参数的通道值。
  • loss 函数的符号约定为
gamma 和 beta 参数梯度的推导

先计算简单的部分,求 loss 对和的偏导:
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

其中表示 gamma 和 beta 参数的第个通道参与了哪些 batch 上向量的 normalize 计算。

因为 gamma 和 beta 上的每个通道的参数都参与了 N 个 batch 上 M 个元素 normalize 的计算,所以对每个通道进行求导的时候,需要把所有涉及到的位置的梯度都累加在一起。
对于在具体实现的时候,就是对应输出梯度的值,也就是从上一层回传回来的梯度值。
输入梯度的推导
对输入梯度的求导是最复杂的,下面的推导都是求 loss 相对于输入张量上的一个点上的梯度,而因为上文已知,每个长度是 M 的向量的计算都是独立的,所以下文也是描述其中一个向量上一个点的梯度公式。具体是计算的时候,是通过向量操作(比如 numpy)来完成所有点的梯度计算。
先看 loss 函数对于的求导:
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

而从上文约定的公式可知,对于
402 Payment Required的计算中涉及到的有三部分,分别是、和。所以 loss 对于的偏导可以写成以下的形式: 深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

接下来就是,分别求上面式子最后三项的梯度公式。
第一项梯度推导
在求第一项的时候,把和看做常量,则有:
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

然后有:
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

最后可得第一项梯度公式:
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

第三项梯度推导
接着先看第三项梯度深度学习|一文理解深度学习框架中的InstanceNorm
文章图片
,因为第三项的推导形式简单一些。
先计算上式最后一项 ,把看做常量:
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

然后计算深度学习|一文理解深度学习框架中的InstanceNorm
文章图片
,等价于求 。而因为每个长度是 M 的向量都会计算一个方差 ,而计算出来的方差又会参数到所有 M 个元素的 normalize 的计算,所以 loss 对于的偏导需要把所有 M 个位置的梯度累加,所以有:
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

接着计算
402 Payment Required ,
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

最后可得:
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

第二项梯度推导

最后计算第二项的梯度深度学习|一文理解深度学习框架中的InstanceNorm
文章图片
,一样先计算最后一项 :
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

接着计算深度学习|一文理解深度学习框架中的InstanceNorm
文章图片
,等价于是求 。而因为每个长度是 M 的向量都会计算一个均值 ,而计算出来的均值又会参与到所有 M 个元素的 normalize 的计算,所以 loss 对于的偏导需要把所有 M 个位置的梯度累加,所以有:
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

接着计算 ,
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

最后可得:
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

输入梯度最终的公式
分别计算完上面三项,就能得到对于输入张量每个位置上梯度的最终公式了:
深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

观察上式可以发现,loss 对的求导公式包括了 loss 对求导的公式,所以这也是为什么先计算第三项的原因,在下面代码实现上也可以体现。

而在具体实现的时候就是直接套公式计算就可以了,下面来看下在 PyTroch 和 MXNet 框架中对 InstanceNorm 的实现。
3 深度学习框架实现代码解读 PyTroch 前向传播实现
前向传播代码链接:
https://github.com/pytorch/pytorch/blob/fa153184c8f70259337777a1fd1d803c7325f758/aten%2Fsrc%2FATen%2Fnative%2FNormalization.cpp#L506
为了可读性简化了些代码:
Tensor instance_norm( const Tensor& input, const Tensor& weight/* optional */, const Tensor& bias/* optional */, const Tensor& running_mean/* optional */, const Tensor& running_var/* optional */, bool use_input_stats, double momentum, double eps, bool cudnn_enabled) { // ...... std::vector shape = input.sizes().vec(); int64_t b = input.size(0); int64_t c = input.size(1); // shape 从 (b, c, h, w) // 变为 (1, b*c, h, w) shape[1] = b * c; shape[0] = 1; // repeat_if_defined 的解释见下文 Tensor weight_ = repeat_if_defined(weight, b); Tensor bias_ = repeat_if_defined(bias, b); Tensor running_mean_ = repeat_if_defined(running_mean, b); Tensor running_var_ = repeat_if_defined(running_var, b); // 改变输入张量的形状 auto input_reshaped = input.contiguous().view(shape); // 计算实际调用的是 batchnorm 的实现 // 所以可以理解为什么 pytroch // 前端 InstanceNorm2d 的接口 // 与 BatchNorm2d 的接口一样 auto out = at::batch_norm( input_reshaped, weight_, bias_, running_mean_, running_var_, use_input_stats, momentum, eps, cudnn_enabled); // ...... return out.view(input.sizes()); }

repeat_if_defined 的代码:
https://github.com/pytorch/pytorch/blob/fa153184c8f70259337777a1fd1d803c7325f758/aten%2Fsrc%2FATen%2Fnative%2FNormalization.cpp#L27
static inline Tensor repeat_if_defined( const Tensor& t, int64_t repeat) { if (t.defined()) { // 把 tensor 按第0维度复制 repeat 次 return t.repeat(repeat); } return t; }

从 pytorch 前向传播的实现上看,验证了本文开头说的关于 InstanceNorm 与 BatchNorm 的联系。还有对于参数 gamma 与 beta 的处理方式。
MXNet 反向传播实现
因为我个人感觉 MXNet InstanceNorm 的反向传播实现很直观,所以选择解读其实现:
https://github.com/apache/incubator-mxnet/blob/4a7282f104590023d846f505527fd0d490b65509/src%2Foperator%2Finstance_norm-inl.h#L112
同样为了可读性简化了些代码:
template void InstanceNormBackward( const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { using namespace mshadow; using namespace mshadow::expr; // ...... const InstanceNormParam& param = nnvm::get( attrs.parsed); Stream *s = ctx.get_stream(); // 获取输入张量的形状 mxnet::TShape dshape = inputs[3].shape_; // ...... int n = inputs[3].size(0); int c = inputs[3].size(1); // rest_dim 就等于上文的 M int rest_dim = static_cast( inputs[3].Size() / n / c); Shape<2> s2 = Shape2(n * c, rest_dim); Shape<3> s3 = Shape3(n, c, rest_dim); // scale 就等于上文的 1/M const real_t scale = static_cast(1) / static_cast(rest_dim); // 获取输入张量 Tensor data = https://www.it610.com/article/inputs[3] .get_with_shape(s2, s); // 保存输入梯度 Tensor gdata = https://www.it610.com/article/outputs[kData] .get_with_shape(s2, s); // 获取参数 gamma Tensor gamma = inputs[4].get(s); // 保存参数 gamma 梯度计算结果 Tensor ggamma = outputs[kGamma] .get(s); // 保存参数 beta 梯度计算结果 Tensor gbeta = outputs[kBeta] .get(s); // 获取输出梯度 Tensor gout = inputs[0] .get_with_shape( s2, s); // 获取前向计算好的均值和方差 Tensor var = inputs[2].FlatTo1D(s); Tensor mean = inputs[1].FlatTo1D(s); // 临时空间 Tensor workspace = //..... // 保存均值的梯度 Tensor gmean = workspace[0]; // 保存方差的梯度 Tensor gvar = workspace[1]; Tensor tmp = workspace[2]; // 计算方差的梯度, // 对应上文输入梯度公式的第三项 // gout 对应输出梯度 gvar = sumall_except_dim<0>( (gout * broadcast<0>( reshape(repmat(gamma, n), Shape1(n * c)), data.shape_)) * (data - broadcast<0>( mean, data.shape_)) * -0.5f * F( broadcast<0>( var + param.eps, data.shape_), -1.5f) ); // 计算均值的梯度, // 对应上文输入梯度公式的第二项 gmean = sumall_except_dim<0>( gout * broadcast<0>( reshape(repmat(gamma, n), Shape1(n * c)), data.shape_)); gmean *= -1.0f / F( var + param.eps); tmp = scale * sumall_except_dim<0>( -2.0f * (data - broadcast<0>( mean, data.shape_))); tmp *= gvar; gmean += tmp; // 计算 beta 的梯度 // 记得s3 = Shape3(n, c, rest_dim) // 那么swapaxis<1, 0>(reshape(gout, s3)) // 就表示首先把输出梯度 reshape 成 // (n, c, rest_dim),接着交换第0和1维度 // (c, n, rest_dim),最后求除了第0维度 // 之外其他维度的和, // 也就和 beta 的求导公式对应上了 Assign(gbeta, req[kBeta], sumall_except_dim<0>( swapaxis<1, 0>(reshape(gout, s3)))); // 计算 gamma 的梯度 // swapaxis<1, 0> 的作用与上面 beta 一样 Assign(ggamma, req[kGamma], sumall_except_dim<0>( swapaxis<1, 0>( reshape(gout * (data - broadcast<0>(mean, data.shape_)) / F( broadcast<0>( var + param.eps, data.shape_ ) ), s3 ) ) ) ); // 计算输入的梯度, // 对应上文输入梯度公式三项的相加 Assign(gdata, req[kData], (gout * broadcast<0>( reshape(repmat(gamma, n), Shape1(n * c)), data.shape_)) * broadcast<0>(1.0f / F( var + param.eps), data.shape_) + broadcast<0>(gvar, data.shape_) * scale * 2.0f * (data - broadcast<0>( mean, data.shape_)) + broadcast<0>(gmean, data.shape_) * scale); }

可以看到基于 mshadow 模板库的反向传播实现,看起来很直观,基本是和公式能对应上的。
4 InstanceNorm numpy 实现 最后看下 InstanceNorm 前向计算与求输入梯度的 numpy 实现:
import numpy as np import torcheps = 1e-05 batch = 4 channel = 2 height = 32 width = 32input = np.random.random( size=(batch, channel, height, width)).astype(np.float32) # gamma 初始化为1 # beta 初始化为0,所以忽略了 gamma = np.ones((1, channel, 1, 1), dtype=np.float32) # 随机生成输出梯度 gout = np.random.random( size=(batch, channel, height, width))\ .astype(np.float32)# 用numpy计算前向的结果 mean_np = np.mean( input, axis=(2, 3), keepdims=True) in_sub_mean = input - mean_np var_np = np.mean( np.square(in_sub_mean), axis=(2, 3), keepdims=True) invar_np = 1.0 / np.sqrt(var_np + eps) out_np = in_sub_mean * invar_np * gamma# 用numpy计算输入梯度 scale = 1.0 / (height * width) # 对应输入梯度公式第三项 gvar = gout * gamma * in_sub_mean * -0.5 * np.power(var_np + eps, -1.5) gvar = np.sum(gvar, axis=(2, 3), keepdims=True)# 对应输入梯度公式第二项 gmean = np.sum( gout * gamma, axis=(2, 3), keepdims=True) gmean *= -invar_np tmp = scale * np.sum(-2.0 * in_sub_mean, axis=(2, 3), keepdims=True) gmean += tmp * gvar# 对应输入梯度公式三项之和 gin_np = gout * gamma * invar_np + gvar * scale * 2.0 * in_sub_mean + gmean * scale# pytorch 的实现 p_input_tensor = torch.tensor(input, requires_grad=True) trans = torch.nn.InstanceNorm2d( channel, affine=True, eps=eps) p_output_tensor = trans(p_input_tensor) p_output_tensor.backward( torch.Tensor(gout))# 与 pytorch 对比结果 print(np.allclose(out_np, p_output_tensor.detach().numpy(), atol=1e-5)) print(np.allclose(gin_np, p_input_tensor.grad.numpy(), atol=1e-5))# 命令行输出 # True # True

本文对于 InstanceNorm 的梯度公式推导大部分参考了博客[1][2]的内容,然后在参考博客的基础上,按自己的理解具体推导了一遍,很多时候是从结果往回推,如果有什么疑惑或意见,欢迎交流。
参考资料
[1]https://medium.com/@drsealks/batch-normalisation-formulas-derivation-253df5b75220
[2]https://kevinzakka.github.io/2016/09/14/batch_normalization/
[3]https://www.zhihu.com/question/68730628
[4]https://arxiv.org/pdf/1607.08022.pdf
【深度学习|一文理解深度学习框架中的InstanceNorm】[5]https://arxiv.org/pdf/1502.03167v3.pdf
[6]https://arxiv.org/pdf/1803.08494.pdf
其他人都在看
  • 一个黑客“沦落”为搬砖的CVer

  • 岁末年初,为你打包了一份技术合订本
  • 一文轻松掌握深度学习框架中的einsum
  • 对抗软件系统复杂性:恰当分层,不多不少
  • 计算机史最疯狂一幕:“蓝色巨人”奋身一跃
  • 30年做成三家独角兽公司,AI芯片创业的底层逻辑

点击“阅读原文”,欢迎下载体验OneFlow新一代开源深度学习框架

深度学习|一文理解深度学习框架中的InstanceNorm
文章图片

    推荐阅读