增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)

论文《Learning to learn without forgetting by maximizing transfer and minimizing interference》中提出了“将经验重放与元学习相结合“的增量学习方法:Meta-Experience Replay (MER)。
这里整理了一下MER的算法流程和代码实现,分别针对任务增量(Task-IL)和类增量(Class-IL)场景下。
论文解析可以戳这里:??????论文解析:Learning to learn without forgetting by maximizing transfer and minimizing interference
目录
1. 算法基础
1.1 Reservior Sampling (蓄水池采样)
1.2 Experience Replay (ER,经验回放方法)
1.3 Reptile
2. Meta-Experience Replay 算法
2.1 MER 算法详解
2.2 任务增量下的代码注释
2.3 类增量下的代码注释
1. 算法基础 1.1 Reservior Sampling (蓄水池采样) Reservior Sampling 是基于经验重放的增量学习方法中常使用的等概率采样方法。
(1) 原理
给出一个数据流,这个数据流的长度很大或者未知。并且对该数据流中数据只能访问一次。写出一个随机选择算法,使得数据流中所有数据被选中的概率相等。
(2) 方法

假设需要采样的数量为k。
首先构建一个可容纳k个元素的数据,将序列的前k个元素放入数据。
然后对第j个元素(j>k),以k/j的概率决定该元素是否被留下(替换到数组中,数组中的k个元素被替换的概率相同)。
(3) 证明
增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片

(4) 算法流程
增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片

1.2 Experience Replay (ER,经验回放方法) (1) 学习目标
核心是保持对已经见过的exemplars的记忆
目标函数:增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片

其中,增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片
为 memory buffer,current size = 增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片
, maximum size = 增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片

原理:使用 Reservioe Sampling 更新 buffer,确保在每一个时间步长里,任何 增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片
个exemplars在 buffer 中被看见的概率都等于 增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片

(2) 算法流程
ER 算法中,每看到新的样本,就对当前 exemplars 进行优先级排序。确保 current exemplars 与 replay buffer 中的例子交叉(因为在继续next example前,希望确保算法能够对current example进行优化,特别是当它还未加入到memory中)
增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片

1.3 Reptile Reptile是元学习中最经典和常用的算法之一。具体的原理可以自行查阅相关文献。
本文的MER就是在Reptile基础上结合增量学习,Reptile基于SGD优化器和学习率增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片
,跨s批次顺序优化。
在a set of s batches上的优化目标为:
增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片

算法流程:
增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片


2. Meta-Experience Replay 算法 这里主要介绍论文中的 Algorithm 1,是单个样本的增量更新。(Algorithm 6 是对一个批次batch的增量更新,原理和代码相差不大。)
2.1 MER 算法详解 原理:MER保持着 Experience Replay 的记忆增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片
,通过 Reservior Sampling 采样。每次时间步提取包括从buffer中k-1个随机样本在内的s个batches。
流程:
1、黄色框为内部更新 inner update:
在Reptile的基础流程。对于 s 个 batches,每个 batch 中的 k 个样本,都进行1次Reptile批处理。
2、绿色框为外部更新 outer update:
根据 inner update 后的模型参数,更新原始模型参数。使用Reservior sampling来更新 memory buffer。
增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片


2.2 任务增量下的代码注释 代码链接:MER/meralg1.py at master · mattriemer/MER · GitHub
(1) Draw batches from buffer:
当前的新样本为 (x,y),结合新样本和从 memory buffer 增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片
中取出的旧样本(经验回放),生成该批次要训练的样本:增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)
文章图片

def getBatch(self, x, y, t): # (x,y): 新看到的样本 xi = Variable(torch.from_numpy(np.array(x))).float().view(1, -1) yi = Variable(torch.from_numpy(np.array(y))).long().view(1) if self.cuda: xi = xi.cuda() yi = yi.cuda()# bxs, bys: 该批次要训练的样本 bxs = [xi] bys = [yi]if len(self.M) > 0: order = [i for i in range(0, len(self.M))] osize = min(self.batchSize, len(self.M)) for j in range(0, osize): shuffle(order) k = order[j] x, y, t = self.M[k] xi = Variable(torch.from_numpy(np.array(x))).float().view(1, -1) yi = Variable(torch.from_numpy(np.array(y))).long().view(1) # handle gpus if specified if self.cuda: xi = xi.cuda() yi = yi.cuda() bxs.append(xi) bys.append(yi)return bxs, bys

在 observe() 中调用:
# Draw batch from buffer bxs,bys = self.getBatch(xi,yi,t)

(2) Inner update 中使用Reptile meta-update:
for step in range(0, self.steps): weights_before = deepcopy(self.net.state_dict()) # Draw batch from buffer: bxs, bys = self.getBatch(xi, yi, t) loss = 0.0 for idx in range(len(bxs)): # 单个样本进行元学习 self.net.zero_grad() bx = bxs[idx] by = bys[idx] prediction = self.forward(bx, 0) loss = self.bce(prediction, by) loss.backward() self.opt.step()weights_after = self.net.state_dict()# Within batch Reptile meta-update: # 更新内部模型的参数 self.net.load_state_dict( {name: weights_before[name] + ((weights_after[name] - weights_before[name]) * self.beta) for name in weights_before})

(3) Outer update 中进行 Reptile 元更新和重新采样
第一步,将内部更新的元模型参数进行外部模型的更新:
# Across batch Reptile meta-update self.net.load_state_dict({name : before[name] + ((after[name] - before[name]) * self.gamma) for name in before})

第二步,使用 Reservoir Sampling 更新 buffer memory:
# Reservoir sampling memory update: if len(self.M) < self.memories: self.M.append([xi, yi, t]) else: p = random.randint(0, self.age) if p < self.memories: self.M[p] = [xi, yi, t]

2.3 类增量下的代码注释 代码链接:La-MAML/meralg1.py at main · montrealrobotics/La-MAML · GitHub
(0) initialization初始化:
在__init__() 函数中,根据类增量的场景,重新设置了每个任务的类别数 nc_per_task
self.n_outputs = n_outputs if self.is_cifar:# Class-IL self.nc_per_task = n_outputs / n_tasks# 每个任务的类别不重叠 else:# Task -IL self.nc_per_task = n_outputs# 每个任务的类别可以看作一样

(1) Draw batches from buffer:
与2.2中的如出一辙,但是多增加了任务t。将任务t也加入到了buffer memory中。
(2) Inner update 中使用Reptile meta-update:
在这里,类增量比2.2(任务增量)新增了一个 compute_offsets() 函数。主要是因为任务增量中,验证的是所有类别(每个任务的类别可以近似看作是一样的)的预测结果;而类增量中,验证的是当前任务中涉及到的类别(每个任务的类别都不重叠)的预测结果。
所以,使用 compute_offsets() 函数来框定只有在这次任务出现的类别:
def compute_offsets(self, task): if self.is_cifar: # Class-IL offset1 = task * self.nc_per_task offset2 = (task + 1) * self.nc_per_task else: # Task-IL offset1 = 0 offset2 = self.n_outputs return int(offset1), int(offset2)

同样,在 forward() 中也应用了compute_offsets() 函数来框定只有在这次任务中出现的类别的预测结果:
def forward(self, x, t): output = self.netforward(x) if self.is_cifar: offset1, offset2 = self.compute_offsets(t)# 不在offset1~offset2的预测结果都剔除 if offset1 > 0: output[:, :offset1].data.fill_(-10e10) if offset2 < self.n_outputs: output[:, int(offset2):self.n_outputs].data.fill_(-10e10) return output

在observe() 函数中,inner update内部更新流程:
for step in range(0, self.steps): weights_before = deepcopy(self.net.state_dict()) ##Check for nan if weights_before != weights_before: ipdb.set_trace() # Draw batch from buffer: bxs, bys, bts = self.getBatch(xi, yi, t) loss = 0.0 total_loss = 0.0 for idx in range(len(bxs)):self.net.zero_grad() bx = bxs[idx] by = bys[idx] bt = bts[idx]if self.is_cifar:# Class-IL offset1, offset2 = self.compute_offsets(bt)# 获得当前任务的index prediction = (self.netforward(bx)[:, offset1:offset2])# 获得当前任务的预测结果 loss = self.bce(prediction, by.unsqueeze(0) - offset1) else:# Task-IL prediction = self.forward(bx, 0) loss = self.bce(prediction, by.unsqueeze(0)) if torch.isnan(loss): ipdb.set_trace()loss.backward() torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.args.grad_clip_norm) self.opt.step() total_loss += loss.item() weights_after = self.net.state_dict() if weights_after != weights_after: ipdb.set_trace()# Within batch Reptile meta-update: self.net.load_state_dict( {name: weights_before[name] + ((weights_after[name] - weights_before[name]) * self.beta) for name in weights_before})

(3) Outer update 中进行 Reptile 元更新和重新采样
与2.2中的如出一辙,但是多增加了任务t。将任务t也加入到了buffer memory中。

以上是我对任务增量/类增量场景下的MER代码的一些理解,从代码上也可以看出任务增量和类增量的异同。如果有写的不对的地方,欢迎指出与讨论~
citation:M. Reimer, I. Cases, R. Ajemian, M. Liu, I. Rish, Y. Tu, G. Tesauro, Learning to learn without forgetting by maximizing transfer and minimizing interference, in: ICLR, 2019.




【增量学习|【元学习】MER代码实现(Task/Class-IL增量场景下的Meta-Experience Replay详解)】

    推荐阅读