python|PyTorch单机多卡分布式训练教程及代码示例

导师不是很懂PyTorch的分布式训练流程,我就做了个PyTorch单机多卡的分布式训练介绍,但是他觉得我做的没这篇好PyTorch分布式训练简明教程 - 知乎。这篇讲的确实很好,不过我感觉我做的也还可以,希望大家看完之后能给我一些建议。
目录
1.预备知识
1.1 主机(Host),节点(Node),进程(Process)和工作结点(Worker)。
1.2 World,Rank,Local Rank
1.2.1 World
1.2.2 Rank
1.2.3 Local Rank
2. PyTorch单机多卡数据并行
2.1 多进程启动
2.1.1 多进程启动示例
2.2 启动进程间通信
2.2.1 初始化成功示例
2.2.2 初始化失败示例
2.2.3 进程间通信示例
2.3. 单机多卡数据并行示例
后记:如何拓展到多机多卡?
1.预备知识 多卡训练涉及到多进程和进程间通信,因此有必要先解释一些进程间通信的概念。
1.1 主机(Host),节点(Node),进程(Process)和工作结点(Worker)。 众所周知,每个主机都可以同时运行多个进程,但是在通常情况下每个进程都是做各自的事情,各个进程间是没有关系的。
而在MPI中,我们可以拥有一组能够互相发消息的进程,但是这些进程可以分布在多个主机中,这时我们可以将主机称为节点(Node),进程称为工作结点(Worker)。
python|PyTorch单机多卡分布式训练教程及代码示例
文章图片


由于PyTorch中的主要说法还是进程,所以后面也会统一采用主机和进程的说法。
1.2 World,Rank,Local Rank 对于一组能够互相发消息的进程,我们需要区分每一个进程,因此每个进程会被分配一个序号,称作rank。进程间可以通过指定rank来进行通信。
1.2.1 World
World可以认为是一个集合,由一组能够互相发消息的进程组成。
如下图中假如Host 1的所有进程和Host 2的所有进程都可以进行通信,那么它们就组成了一个World。

python|PyTorch单机多卡分布式训练教程及代码示例
文章图片

因此,world size就表示这组能够互相通信的进程的总数,上图中world size为6。
1.2.2 Rank
Rank可以认为是这组能够互相通信的进程在World中的序号。
python|PyTorch单机多卡分布式训练教程及代码示例
文章图片


1.2.3 Local Rank
Local Rank可以认为是这组能够互相通信的进程在它们相应主机(Host)中的序号。
即在每个Host中,Local rank都是从0开始。
python|PyTorch单机多卡分布式训练教程及代码示例
文章图片


2. PyTorch单机多卡数据并行 数据并行本质上就是增大模型的batch size,但batch size也不是越大越好,所以一般对于大模型才会使用数据并行。
Pytorch进行数据并行主要依赖于它的两个模块multiprocessing和distributed。
所以首先介绍multiprocessing和distributed模块的基本用法。
2.1 多进程启动 由于Python多线程存在GIL(全局解释器锁),为了提高效率,Pytorch实现了一个multiprocessing多进程模块。用于在一个Python进程中启动额外的进程。
2.1.1 多进程启动示例
该程序启动了4个进程,每个进程会输出当前rank,表明与其他进程不同。

#run_multiprocess.py #运行命令:python run_multiprocess.py import torch.multiprocessing as mpdef run(rank, size): print("world size:{}. I'm rank {}.".format(size,rank))if __name__ == "__main__": world_size = 4 mp.set_start_method("spawn") #创建进程对象 #target为该进程要运行的函数,args为target函数的输入参数 p0 = mp.Process(target=run, args=(0, world_size)) p1 = mp.Process(target=run, args=(1, world_size)) p2 = mp.Process(target=run, args=(2, world_size)) p3 = mp.Process(target=run, args=(3, world_size))#启动进程 p0.start() p1.start() p2.start() p3.start()#当前进程会阻塞在join函数,直到相应进程结束。 p0.join() p1.join() p2.join() p3.join()

输出结果:
world size:4. I'm rank 1. world size:4. I'm rank 0. world size:4. I'm rank 2. world size:4. I'm rank 3.

2.2 启动进程间通信 虽然启动了多进程,但是此时进程间并不能进行通信,所以PyTorch设计了另一个distributed模块用于进程间通信。
init_process_group函数是distributed模块用于初始化通信模块的函数。
当该函数初始化成功则表明进程间可以进行通信。
2.2.1 初始化成功示例
只有当world size和实际启动的进程数匹配,init_process_group才可以初始化成功。
#multiprocess_comm.py #运行命令:python multiprocess_comm.pyimport os import torch.distributed as dist import torch.multiprocessing as mpdef run(rank, size): #MASTER_ADDR和MASTER_PORT是通信模块初始化需要的两个环境变量。 #由于是在单机上,所以用localhost的ip就可以了。 os.environ['MASTER_ADDR'] = '127.0.0.1' #端口可以是任意空闲端口 os.environ['MASTER_PORT'] = '29500' #通信模块初始化 #进程会阻塞在该函数,直到确定所有进程都可以通信。 dist.init_process_group('gloo', rank=rank, world_size=size) print("world size:{}. I'm rank {}.".format(size,rank))if __name__ == "__main__": world_size = 4 mp.set_start_method("spawn") #创建进程对象 #target为该进程要运行的函数,args为函数的输入参数 p0 = mp.Process(target=run, args=(0, world_size)) p1 = mp.Process(target=run, args=(1, world_size)) p2 = mp.Process(target=run, args=(2, world_size)) p3 = mp.Process(target=run, args=(3, world_size))#启动进程 p0.start() p1.start() p2.start() p3.start()#等待进程结束 p0.join() p1.join() p2.join() p3.join()

输出结果:
world size:4. I'm rank 1. world size:4. I'm rank 0. world size:4. I'm rank 2. world size:4. I'm rank 3.

2.2.2 初始化失败示例
当将world size设置为2,但是实际却启动了4个进程,此时init_process_group就会报错。
#multiprocess_comm.py #运行命令:python multiprocess_comm.pyimport os import torch.distributed as dist import torch.multiprocessing as mpdef run(rank, size): #MASTER_ADDR和MASTER_PORT是通信模块初始化需要的两个环境变量。 #由于是在单机上,所以用localhost的ip就可以了。 os.environ['MASTER_ADDR'] = '127.0.0.1' #端口可以是任意空闲端口 os.environ['MASTER_PORT'] = '29500' #通信模块初始化 #进程会阻塞在该函数,直到确定所有进程都可以通信。 dist.init_process_group('gloo', rank=rank, world_size=size) print("world size:{}. I'm rank {}.".format(size,rank))if __name__ == "__main__": world_size = 2 mp.set_start_method("spawn") #创建进程对象 #target为该进程要运行的函数,args为target函数的输入参数 p0 = mp.Process(target=run, args=(0, world_size)) p1 = mp.Process(target=run, args=(1, world_size)) p2 = mp.Process(target=run, args=(2, world_size)) p3 = mp.Process(target=run, args=(3, world_size))#启动进程 p0.start() p1.start() p2.start() p3.start()#当前进程会阻塞在join函数,直到相应进程结束。 p0.join() p1.join() p2.join() p3.join()

输出结果:
RuntimeError: [enforce fail at /opt/conda/conda-bld/pytorch_1623448224956/work/third_party/gloo/gloo/context.cc:27] rank < size. 3 vs 2

2.2.3 进程间通信示例
当init_process_group初始化成功,进程间就可以进行通信了,这里我以集体通信Allreduce为例。
#multiprocess_allreduce.py #运行命令:python multiprocess_allreduce.pyimport os import torch import torch.distributed as dist import torch.multiprocessing as mpdef run(rank, size): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29500' #通信模块初始化 #进程会阻塞在该函数,直到确定所有进程都可以通信。 dist.init_process_group('gloo', rank=rank, world_size=size) #每个进程都创建一个Tensor,Tensor值为该进程相应rank。 param = torch.tensor([rank]) print("rank {}: tensor before allreduce: {}".format(rank,param)) #对该Tensor进行Allreduce。 dist.all_reduce(param.data, op=dist.ReduceOp.SUM) print("rank {}: tensor after allreduce: {}".format(rank,param))if __name__ == "__main__": world_size = 4 mp.set_start_method("spawn") #创建进程对象 #target为该进程要运行的函数,args为target函数的输入参数 p0 = mp.Process(target=run, args=(0, world_size)) p1 = mp.Process(target=run, args=(1, world_size)) p2 = mp.Process(target=run, args=(2, world_size)) p3 = mp.Process(target=run, args=(3, world_size))#启动进程 p0.start() p1.start() p2.start() p3.start()#当前进程会阻塞在join函数,直到相应进程结束。 p0.join() p1.join() p2.join() p3.join()

输出结果:
rank 0: tensor before allreduce: tensor([0]) rank 2: tensor before allreduce: tensor([2]) rank 3: tensor before allreduce: tensor([3]) rank 1: tensor before allreduce: tensor([1])rank 0: tensor after allreduce: tensor([6]) rank 3: tensor after allreduce: tensor([6]) rank 2: tensor after allreduce: tensor([6]) rank 1: tensor after allreduce: tensor([6])

2.3. 单机多卡数据并行示例 当可以启动多进程,并进行进程间通信后,实际上就已经可以进行单机多卡的分布式训练了。
但是Pytorch为了便于用户使用,所以在这之上又增加了很多更高层的封装,如DistributedDataParallel,DistributedSampler等。
所以为了便于理解这中间的一些流程,这里演示一下不使用这些封装时的单机多卡数据并行。
该示例代码和单机训练主要有两个区别:
(1)需要基于每个进程的rank将模型参数放置到不同的GPU。
(2) 在参数更新前需要对梯度进行Allreduce。
#multiprocess_training.py #运行命令:python multiprocess_training.py import os import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn import torchvision import torchvision.transforms as transforms #用于平均梯度的函数 def average_gradients(model): size = float(dist.get_world_size()) for param in model.parameters(): dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) param.grad.data /= size #模型 class ConvNet(nn.Module): def __init__(self, num_classes=10): super(ConvNet, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2)) self.layer2 = nn.Sequential( nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2)) self.fc = nn.Linear(7*7*32, num_classes) def forward(self, x): out = self.layer1(x) out = self.layer2(out) out = out.reshape(out.size(0), -1) out = self.fc(out) return outdef accuracy(outputs,labels): _, preds = torch.max(outputs, 1) # taking the highest value of prediction. correct_number = torch.sum(preds == labels.data) return (correct_number/len(preds)).item()def run(rank, size): #MASTER_ADDR和MASTER_PORT是通信模块初始化需要的两个环境变量。 #由于是在单机上,所以用localhost的ip就可以了。 os.environ['MASTER_ADDR'] = '127.0.0.1' #端口可以是任意空闲端口 os.environ['MASTER_PORT'] = '29500' dist.init_process_group('gloo', rank=rank, world_size=size)#1.数据集预处理 train_dataset = torchvision.datasets.MNIST(root='../data', train=True, transform=transforms.ToTensor(), download=True) training_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)#2.搭建模型 #device = torch.device("cuda:{}".format(rank)) device = torch.device("cpu") print(device) torch.manual_seed(0) model = ConvNet().to(device) torch.manual_seed(rank) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr = 0.001,momentum=0.9) # fine tuned the lr #3.开始训练 epochs = 15 batch_num = len(training_loader) running_loss_history = [] for e in range(epochs): for i,(inputs, labels) in enumerate(training_loader): inputs = inputs.to(device) labels = labels.to(device) #前向传播 outputs = model(inputs) loss = criterion(outputs, labels) optimizer.zero_grad() #反传 loss.backward() #记录loss running_loss_history.append(loss.item()) #参数更新前需要Allreduce梯度。 average_gradients(model) #参数更新 optimizer.step() if (i + 1) % 50 == 0 and rank == 0: print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f},acc:{:.2f}'.format(e + 1, epochs, i + 1, batch_num,loss.item(),accuracy(outputs,labels)))if __name__ == "__main__": world_size = 4 mp.set_start_method("spawn") #创建进程对象 #target为该进程要运行的函数,args为target函数的输入参数 p0 = mp.Process(target=run, args=(0, world_size)) p1 = mp.Process(target=run, args=(1, world_size)) p2 = mp.Process(target=run, args=(2, world_size)) p3 = mp.Process(target=run, args=(3, world_size))#启动进程 p0.start() p1.start() p2.start() p3.start()#当前进程会阻塞在join函数,直到相应进程结束。 p0.join() p1.join() p2.join() p3.join()

后记:如何拓展到多机多卡? 在多机多卡环境中初始化init_process_group还需要做一些额外的处理,主要考虑两个问题
(1)需要让其余进程知道rank=0进程的 IP:Port 地址,此时rank=0进程会在相应端口进行监听,其余进程则会给这个IP:Port发消息。这样rank=0进程就可以进行统计,确认初始化是否成功。这一步在PyTorch中是通过设置os.environ['MASTER_ADDR']和os.environ['MASTER_PORT']这两个环境变量来做的。
(2)需要为每个进程确定相应rank,通常采用的做法是给主机编号,因此多机多卡启动时给不同主机传入的参数肯定是不同的。此时参数可以直接手动在每个主机的代码上修改,也可以通过argparse模块在运行时传递不同参数来做。


【python|PyTorch单机多卡分布式训练教程及代码示例】

    推荐阅读