分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练

在 3.25 日的 MegEngine Meetup 中,旷视研究院周亦庄讲师分享了《利用 MegEngine 分布式通信算子实现复杂的并行训练》。
直播回放链接:利用 MegEngine 的分布式通信算子实现复杂的并行训练 - MegEngine Meetup No.2_哔哩哔哩 (゜-゜)つロ 干杯~-bilibili
分享内容主要分为四个部分:
1. 介绍 MegEngine 的分布式通信算子;
2. 简单参数并行,用于熟悉模型并行的一些基本概念;
3. 层内模型并行;
4. 层间模型并行和流水线并行,同时介绍了如何实现一个简单的 GPipe。
以下为该分享的文字实录,Enjoy~
一、背景 并行训练是开展深度学习研究和业务非常重要的一环,很多基础研究都需要大规模的计算集群甚至是超级计算机来完成。比如,像我们知道的 DeepMind 下围棋的 AlphaGo,还有OpenAI 的 1750 亿(175 billion)参数的超大语言模型 GPT-3,最近 OpenAI 还搞了一个 CLIP 和 DALL-E,他们都是用非常大的集群来进行分布式训练的。而因为旷视研究院有 Brain++ 这个分布式的计算平台,所以我们也有很多优秀的成果。大模型在各类视觉和语言任务上相比于小模型都有显著优势,所以最近的一种趋势是模型规模、数据规模越大越好,“大即正义”,因此更需要大规模的并行训练。
并行训练,一方面可以调动上百甚至上千块 GPU(图形处理器,又称”显卡”,简称”卡”,是深度学习最常见的计算设备)进行训练,第二部分也是根据业务或模型的特点,我们可以设计出最高效的并行模式。这是我今天讲的并行训练的一个现实意义。
先来讲一下深度学习当中有三种比较常用的并行模式,三种并行模式的关系用下面这张图就可以表达清楚。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

第一种(层内模型并行),是利用矩阵乘法天然的并行特性,把每层(比如全连接层或卷积层)内部的矩阵乘法计算给拆开,表现为沿着输入/输出 通道(channel) 拆开进行分组计算,这就叫层内模型并行。
第二种(层间模型并行),是利用神经网络串行执行的特性,把网络按照执行顺序拆开,分别放到不同的设备上进行计算,比如说我们一个 ResNet18,它有 17 层卷积层加上最后一层全连接层,如果我们把前九层的和后九层的计算放到两块卡(即 GPU/显卡)上,它就是叫层间模型并行。层间与层内这两种模型并行方式是“正交”的,互不影响,可以同时存在。
以上说的两种并行,它的模型参数都是拆开来的,每个计算节点(计算节点是底层计算设备的一种抽象,它可以是一张卡,也可以是一台或者一组 8 卡机,即装载 8 块 GPU 的计算机)只负责管理整个网络的一部分参数以及这部分参数参与的相应计算。
最后一种就是我们最常用的数据并行,它又是另外一个维度,在数据并行维度上,模型参数都是共享的,但是接收的数据是不一样的。通过增加计算设备,我们可以近似线性地增加单次迭代的 batch size(批量,即训练图片的数量),从而节省训练模型的时间。
这三种并行维度是两两正交的,意思是在实际训练中我们既会用到两种模型并行也会用到数据并行。小模型可能数据并行就足够了,但大模型由于参数特别多、计算量非常大,计算难以用单个 GPU 完成,这时候就要将计算拆解到不同 GPU 上,此即模型并行。
二、MegEngine 的通信算子 接下来,进入到今天要讲的正题。先说通信算子。
人类的历史它其实就是一个信息交互的历史,也就是一个通信的历史——人与人之间说话就是通信,我今天做直播,它其实也是通信,我把信息广播给大家,这也是通信,电视和广播当然也是通信。
对于深度学习框架来说,通信是最重要的功能之一,否则数据并行和模型并行难以实现。简单来说就是我有很多个计算设备(GPU),我需要让信息在所有计算设备之间进行交互,那就需要集合通信——集合通信是一个求导完备的一套通信规则。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

表中列了有 8 种集合通信算子和 2 种点对点通信算子,这就是 MegEngine 全部的通信算子。8 种集合通信算子,构成一套求导完备的通信的规则,它们互相各自为导数。MegEngine 提供了对通信算子的自动求导,所以和其它所有用于计算的算子(如卷积、ReLU、转置等)一样,我们可以自由地把通信算子加入前向计算图,框架将负责对其求导。
考虑到有些同学没有背景知识,我们一一介绍一下集合通信算子的功能。
Broadcast
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

Broadcast 即广播。
它表示的是数据的一个同步的过程,将一张 GPU 上的信息同步给其它所有 GPU。这在数据并行中非常有用,因为数据并行的话,每张卡上面的参数应该确保都是一样的,因此在初始化时我们会通过 Broadcast 进行参数同步,我们也会周期性同步一些缓存信息(buffer,比如 BatchNorm 的统计量)。
ReduceSum 分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

第二个是 ReduceSum。 ReduceSum 叫做求和或者归约,将所有 GPU 上的数据收集到一个 GPU 上并相加。
我们刚才讲的 Broadcast 和 ReduceSum 这两个通信算子是构成参数服务器 Parameter Server 的一个基石,它是中心式的,在这里面 GPU0 就起到一个中心的作用,我先把中心参数通过 Broadcast 同步给各张卡进行前传,反传后通过Reduce收集各张卡的梯度,进行参数更新。Broadcast 和 ReduceSum ,互为导数的, ReduceSum 的导数就是 Broadcast,Broadcast 的导数是 ReduceSum 。
AllReduce 分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

我们再介绍 AllReduce,本来 Reduce 是归约到一张卡上,AllReduce 则是归约到每一张卡上。它即可以理解为 Reduce Broadcast 的组合,即我先 Reduce 到一张卡,然后再 Rroadcast 到所有卡;也可以理解为每张卡都同时调用了 Reduce,AllReduce 它的导数就是 AllReduce 本身。
尽管只用 Reduce 和 Broadcast 就可以实现 AllReduce,但是 AllReduce 的高效实现(即 Ring-AllReduce)才是构成现代 分布式深度学习框架的基石,它的通信时间基本不随 GPU 数量的增加而增加,因此可以高效地实现分布式训练的规模化。在数据并行中,我们用 AllReduce 将所有梯度求和,并用于模型参数更新。
Gather
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

Gather 简单来说就是把每张卡上不同的信息都给收集过来,并沿着第一维相连(Concatenate)。
AllGather 分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

AllGather 就是全收集,和 AllReduce 类似,它可以理解为 Gather 后接 Broadcast。
AllGather 是我们层内模型并行当中的一个很重要的操作,因为你的参数在不同的卡上,你的数据也在不同的卡上,在进行模型并行的时候,我需要把数据或者参数都收集起来放到一张卡上才能进行接下来的计算,这就是 AllGather 的作用。
AllToAll
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

AllToAll 也是层内模型并行中经常用到的一个操作,特别是在模型并行和数据并行进行切换的时候,它本质上对一个矩阵进行了转置,我们后面在具体应用中会进一步说明。AllToAll 的导数是它本身。
Scatter 和 ReduceScatter
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

最后 Scatter 和 ReduceScatter 合起来讲,Scatter 就是分发,它将一张卡上的数据拆分给各张卡,它和 Gather 互为导数。
ReduceScatter 可以理解为在分发之前先进行了求和,它和 AllGather 互为导数。
三、简单参数并行 介绍完 MegEngine 的通信算子,我们来了解它们如何使用。首先,让我们从简单参数并行开始,它只涉及 AllGather 这一通信算子。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

简单参数并行是怎么一回事?我们先用一个简单的全连接层(即矩阵乘法)来回顾一下数据并行——数据并行中,W 是我们的模型(即我们的权重 weight,每张卡拥有一份同样的拷贝),x 是数据。数据并行要求我们将数据平均拆分到每张卡上,2 卡拆 2 份,即 x0 和 x1,4 卡则拆成 4 份,依此类推,各张卡分别进行矩阵乘法计算,得到对应的结果 y。
简单参数并行本质是数据并行的优化?我们不必在每张卡上都放完整的模型,而是只放部分模型,只有在我们需要(即前传)的时候,把分散在各张卡上的参数收集(AllGather)起来参与计算。
如何实现?我们在做矩阵乘法操作之前,先对参数进行 AllGather,从各个节点上收集被我们拆开的参数,AllGather 以后每张卡都有全部的权重了,计算就变得和数据并行一模一样的。所以,简单参数并行的核心操作就是 AllGather,本质用通信来节省显存。
为什么能节省显存呢?我们现在把整个求导过程也画出来了,我们知道在训练一份参数的时候,它其实是会占掉三份显存——参数一份,梯度一份,优化器的 momentum 一份,所以一个参数量 1 million 的模型,如果我们使用数据并行,会占用 3 * 4G = 12G(1 million fp32 类型的数据占用 4G)的显存,那我们一张 2080ti 就完全没有显存可以用于训练了。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

我们再来研究一下这张图,我们在前传的时候做了一次 AllGather,在反传的时候,我们知道 AllGather 的导数是 ReduceScatter,所以,它反传的时候会进行一次 ReduceScatter 。这和数据并行不一样,数据并行前传不需要通信,反传需要进行 AllReduce这是他们的区别。
我们用 MegEngine 写了一套数据并行和简单参数并行的代码,它们有三个不同:
  • 一个不同是它们的前传是不一样的——右边(简单参数并行)就是要做一次 AllGather ;
  • 还有一个不同就是他们在参数初始化的时候。在数据并行中我们需要参数同步,所以我们要 Broadcast,但是在简单参数并行里面,我们需要的是参数分发,所以用 Scatter,就把它们给分发出去。
  • 最后一个不同就是在求导的时候,求导的时候在数据并行当中我们需要进行 AllReduce(MegEngine 使用 AllReduce callback 来支持数据并行),但是在简单参数并行里面不需要进行 AllReduce,自动微分器会负责反传时正确调用 ReduceScatter。
【分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练】分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

四、层内模型并行 层内模型并行在原理上更加复杂。我们刚才讲的参数并行,它其实是一种层内模型并行的一种特例,因为它非常的简单,只需要对参数进行 AllGather。实际上我们的层内模型并行还有多种不一样的实现。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

上图给出了完整的矩阵乘法、数据并行和两种模型并行的实现。
我们知道矩阵乘和卷积神经网络中的卷积层(卷积层可以视为对 channel 维度进行的矩阵乘),都天然具有并行的特性。我们在数学意义上的矩阵乘法,每一行每一列的运算都可以独立进行,数据并行就充分的利用了这个特性,我们把数据进行平均切分,各自放在不同的设备上各自做矩阵乘法,最后可以合并起来得到完整结果。
在层内模型并行当中,我们是把每层(全连接/卷积层)的参数矩阵 W 进行切分。一种方式是按输出维度进行切分(纵切)。第二种种类是按输入维度进行切分(横切)。前者在每张卡上得到部分输出维度的对应结果;后者利用了矩阵的低秩特性(Low Rank),每张卡的结果是最终结果的低秩分量,后续须通过 AllReduce 或者 ReduceScatter 将其求和。
接下来我们在多层神经网络中应用层内模型并行——我们实现纯粹的层内模型并行,或者和数据并行搭配使用,完成混合并行。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

上图第一行是纯数据并行。数据在一开始就被切分到各张卡上,之后不需要进行交换或信息交流,因此数据并行后接数据并行不需要进行特殊操作。
第二种纯层内模型并行。首先你需要完整样本数(batch)的输入特征“X”,最后矩阵乘出来它是完整样本数但部分输出通道数(channel)的特征“Y”,为了后续继续进行模型并行的矩阵乘法,我必须做一次 AllGather,把“Y”沿着通道(channel)收集起来,把它再变成样本数和通道数皆完整的“Y”,再与模型并行的“V”相乘。如果网络继续加深,那么每次矩阵乘结束都要进行 AllGather 操作。
第三种混合并行混合了数据并行与层内模型并行。我们还是以模型并行开始,模型并行的全连接层输出一个纵切的“Y”(即沿输出通道切分的特征 Tensor),但是我们数据并行要的是横切的“Y”(即沿样本数维度切分的特征 Tensor),应该怎么操作?在介绍 MegEngine 通信算子的时候我们提到一个转置操作叫 AllToAll,它可以直接把这个纵切的“Y”变成了横切的“Y”。接下来我们就可以恢复数据并行了,进行一次数据并行的矩阵乘法后,我们还想进行一次模型并行的矩阵乘法,那就再做一次 AllGather,得到全部样本数且全部通道数的完整特征 Tensor。掌握了利用 AllToAll 和 AllGather 实现的“切换”以后,你就可以自己设计与训练混合并行的模型。
接下来我们举例两个应用场景。
场景一:全连接的层内模型并行
我们来进入一个具体场景,在人脸识别任务中应用全连接的层内模型并行。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

在人脸识别任务当中,可能有百万、千万的 ID(Identity,同一个人为一个 ID),相当于要去做一个输出维度为百万/千万的分类任务,所以,最后这一层,分类的这一层 FC 层(全连接层)它可能参数特别大,比如说我们有一百万(1 million)的 ID,提取的人脸特征是一个 1024 维的向量,它们乘起来就会占用 4 个 G 显存,我们刚才提到 4G 参数的模型在实际训练中会固定占用 3 倍显存,就是 12G,一般的显卡装不下。我只能把这个全连接给放到各张卡上,如果我们有 8 张卡,每张卡就只会分到 1.5G,那么还是可以接受的。这个场景的特点是什么?就是人脸特征维度相比于我的参数矩阵其实非常小的,所以我们对数据进行通信(AllGather),它的代价要比对权重进行通信(AllReduce)它的代价小得多,所以在这个场景下特别适合做模型并行。
在模型并行下分类器 W 输出的结果 Y 的具体含义是什么?我们知道 Y 是竖着切分的,竖着这一维是样本(batch)维,就是它有多少个训练的样本,横着的这一维其实是 ID 维度,就是类别维,表示样本属于各个 ID 的概率,而模型并行下它只输出了一部分标签的概率。求损失函数的时候我们往往用交叉熵(CrossEntropy),交叉熵需要全部的类别概率。没错,利用之前我们介绍的 AllToAll 算子,我们把输出的模型并行的概率矩阵给进行 AllToAll转置,它就变回了数据并行的格式。(讲师注:实际上你并不需要进行 AllToAll,在分类任务的特殊场景下,你并不需要 AllToAll,因为通信代价很大,你可以籍由两次极低代价的通信来实现交叉熵的计算,但是这个超纲了,但不是很困难,留给大家当思考题。)
我们直接上代码。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

整个过程中有三步,第一步是 AllGather,第二步进行矩阵乘,第三步进行 AllToAll。
那么上图框起来的这段代码是什么东西呢?我们做了这么多 reshape,什么 transpose——这叫数据重排布,我们再花 5 分钟的时间来讲一下数据重排布是什么。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

我们 AllToAll 做完以后,得到的其实并不是我们想要的部分数据加上全部分类的一个结果,它其实在底层的数据排布(layout)上面它不是我们期望的。上图是 1 个简化版本的例子,它的分类从 0-7 总共有 8 类,它的样本是 4 张人脸图片。经过模型并行,在卡 0 上面我们得到的输出是 0-3 类的结果,卡 1 上面得到的是 4-7 类的结果。我们做完 AllToAll 以后它变成的矩阵(0,1,2,3,10,11,12,13)并不是我们想要的,我们最后想要的就是 0,1,2,3,4,5,6,7,下面是 10-17,所以的话我们必须先做一次 reshape,沿着这个方向是最里面维 0,1,2,3 数据是连续的,我们把这外面两维(0,10,4,14)个给进行一次转置,就是转过来,最后 reshape 为想要的结果。为了以后使用方便,我简单进行了以下两个封装,上面封装叫 mp2dp,就是从模型并行变成数据并行(Data Parallelism)的一个封装,下面这个是 dp2mp,有了这两个封装以后,我们上面的前传代码就变得简单了。
场景二:组卷积模型并行
讲完了全连接,接下来我们再讲组卷积(Group Convolution),
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

Group Convolution 在我们的移动端模型上面特别常见,组卷积和普通卷积它的区别就在于组卷积相当于 K 个普通卷积。比如说你有三组,就相当于三个普通卷积,但是每个普通卷积都比自己的小,你们也可以发现这个是天然并行的,上图红色的、绿色的、黄色其实可以各自做,在不同的设备上做。
下图用之前二维的表示抽象一下卷积和组卷积的不同——组卷积的模型,它和卷积不一样,组卷积相当于一个稀疏的矩阵乘法,它不是一个稠密的的矩阵(dense matrix)。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

数据并行情况下和普通卷积一样,我们把数据进行切分;模型并行我们可以直接按颜色把这三个组分开,我们第一块卡上做第一个组,第二块卡上做第二个组,第三块卡上做第三个组,对于每块卡来说,原本的组卷积计算都变成了普通的卷积操作。
如果我们前面是普通卷积,中间要插入一组模型并行的组卷积,我们应该怎么样从这两种数据排布之间切换?
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

很简单,我们就做一次数据重排布(即 AllToAll),由于是数据并行到模型并行,所以我们调用transpose_dp2mp。
如果我们有多个组卷积,他们连在一起,实际上我们并不需要反复地在数据和模型并行间切换,我们只需要关注头和尾。所以,我们的组卷积在前传函数里面有一个叫 is_head 和 is_tail,我们 is_head 的时候,我们做一次通信, is_tail 的时候再做一次通信,我们中间就完全不需要通信了。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

五、层间模型并行 我们进入层间模型并行,刚才的层内模型并行我们介绍了相关原理和应用(全连接和组卷积)。层间模型并行和层内模型并行很不一样,主要就是简单模型并行和流水线并行。层间模型并行简单来说就是把网络的前半部分、中间部分和后半部分分开(甚至分成更多份),就像一条鱼,鱼头、鱼中和鱼尾。
我们简单来看一下数据并行和层间模型并行的对比示意图。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

数据并行就是把数据切开,层间模型并行不切数据,而是把模型的前半部分和后半部分给拆分到不同的 GPU 上,这边就涉及到一个问题,怎么把“Y”第一块 GPU 的输出结果,给“放”到第二块 GPU 上,这里面就需要 send 操作。MegEngine 提供了八个集合通信算子,加上两个点对点通信算子——一个就是 send,一个就是 receive。这两个算子组成了层间模型并行的核心操作,接下来主要讲 send receive。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

如果层间模型并行,我们用一个图表来抽象的话(如上图下半部分),横轴是计算时间,随着计算推进,纵轴是我们的计算设备(GPU),我们发现任务之间存在依赖关系,所以 GPU 0 算完后必须做 send 操作,同时卡 1 做 receive 接收卡 0 的结果,然后进行自己的计算,算完再 send,卡 2 receive……这样才能做完一个流程。
为了方便起见,我们这边又做了一次封装,第一个函数是把我们出来的计算结果给发到下一个 GPU,这个函数是下一块 CPU 调用的,就是它从上一个 GPU 去给它拿出去,MegEngine 自带的 recv 不带自动的形状和类型推导(讲师注:在 MegEngine 的下个版本即将支持),因此封装的时候我也简单实现了一下。
简单模型并行 分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

我们直接看代码,在普通的数据并行里面,这是一个简单的 ResNet 18的模型,它总共有 17 层卷积加上一层全连接,在简单模型并行里面,如果它是第 1 块 GPU,它就负责第一部分的 5 层卷积,第2第 3 块各负责 4 层卷积,最后一块 GPU 负责 4 层卷积和最后的一层全链接。
在前传的时候先进行判断——当我们如果不是第 1 块 GPU 的话,我们就从前面一块卡拿数据。之后进行自己负责的卷积计算。得到结果后再次进行判断——如果不是最后一块 GPU,我们要把我的数据给送到下一块 GPU 上,如果是最后一块,就直接 return。
我们可以用代码来展示简单模型并行的推理和训练的结果:
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

在推理过程中,输入一张组(32张) 224分 辨率的图片,前三块 GPU 输出的都是网络的中间特征,最后的 GPU 输出的是网络的预测值。在训练当中值得一提的:第一,因为是模型并行,所以我们不需要进行 AllReduce;第二,前三块 GPU 在调用 gm.backward 时传入了一个 None,其实我们在设计 API 的时候,backward 任何东西都可以,backward None 在这里会发生什么?由于前传有一个 send,所以自动微分的时候就会插入一个 recv,它会先等待来自下游的梯度,然后进行正常的反传。
流水线并行 我们接下来讲流水线并行。简单的模型并行需要算完同一批次的全部的数据再给下一个批次的数据,实际上每一张卡都会有很长时间的空闲期,它要么在等上一块卡跑完,要么完成了自己这一批的任务,在等待下一批次的数据。
如果我们把一个批次的数据给分成很多小份的话,我们可以让第 0 块卡先算一小份,算完以后立马送给下一块卡,然后再计算下一小份,这样子的话这个时刻卡 0 和卡 1 可以同时算,空置率就下去了。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

这就是流水线并行的一个核心思想,我们看一下它代码怎么实现。
比如在这个里面,我们想要把一份数据给拆成 4 份,我们用 F.split 将它拆成 4 分,然后遍历一遍这 4 份数据,如果它是第一块卡,它就拿那个数据,不然的话它会等,等着接收前一块卡的计算结果。不管怎么样拿到数据以后的事情就是进行计算,计算完以后我们要处理计算结果——和简单模型并行一样,如果他不是最后一块 GPU,我要把它送到下一块,如果它是最后一块 GPU 的话,就直接出来返回结果。
这就是流水线并行。当然到实际场景中流水线并行的代码需要考虑执行效率,没有这么简单,比如说会引入异步 send/recv,以降低等待时间。
我们不光要推理,我们还要训练,训练的话就涉及到一个反传,在普通的模型并行当中,我们的反传和前传时间轴是如下图所示:
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

我们先前传完,再依次反传。但是在我们流水线并行里面,其实反传也是一个流水线的过程。但是这里面有个特殊的地方,注意一下重新前传(或重算)。如果我们不重新前传的话,意味着我们前面的这些中间结果都要保留着等待反传结束后才能丢弃/释放,这意味着我的宝贵的显存又要被浪费了,这样子的话我们还不如算完就全部扔掉,因为我已经把结果交给下一块 GPU 了,暂时就不需要了。而反传时我们还需要中间结果的时候,我大不了再重算一次(换句话说每张卡只要保留自己的输入就可以了)。重算后我们可以正常做反传,得到关于输入的梯度,然后把这份梯度传给上一张卡。上一张卡同样执行重算、反传和发送梯度,直到所有卡都完成了梯度计算。
重新前传的操作叫做 checkpoint 或 sublinear,在 PyTorch 里面有 checkpoint,在 MegEngine 里也有 sublinear,我们目前实现的是非常粗粒度的 sublinear,它不是中间保留几个结果重算部分就可以了,它其实是全部都重算了,这就是 GPipe。
分享实录 | 利用 MegEngine 分布式通信算子实现复杂的并行训练
文章图片

前传还是一样的代码,如上图左侧给大家做一个参考。
反传是精妙的地方,我们拿到 label,loss 以后看一下,第一就是我们 GradManager,这是 MegEngine 一个非常重要的特性,就是 GradManager 可以对中间的 feature(就是中间结果)进行求导,所以我们可以在计算过程中对中间变量进行 attach,在 GPipe 的场景下,我们需要的是对输入的导数,所以我们在一开始就 attach 输入数据 x,然后进行前传(或者称为重算)。如果它是最后一张卡的话,我们就计算相应的损失,并把梯度算出来。通过 grad_to_prev_gpu,我们把关于输入的梯度传给了上一张 GPU。后一块卡关于输入的梯度即前一块卡输出的梯度 dy。我们通过 gm.backward(dy=grad)手动指定梯度,从而完成中间 GPU 的求导过程。这就是一个简单的 GPipe。
如果大家想试着玩一下这个 GPipe 的话,在 GitHub 上面 MegEngine Parallel Tutorial 是我写的,大家可以去跑一下玩一下。
欢迎小伙伴加入我们MegEngine旷视天元开发者交流 QQ 群:1029741705
框架使用相关交流或反馈,欢迎访问论坛: http://discuss.megengine.org.cn;
GitHub 项目地址: http://github.com/MegEngine/M...

    推荐阅读