pytorch | Loading BERT/ALBERT models: From pytorch_model.bin

Whenever I load a pretrained BERT/ALBERT model, I find myself curious about how the parameters actually get filled into the model skeleton, step by step. I also want a clearer way to confirm that the pretrained parameters really were loaded, to guard against a checkpoint being silently mis-loaded. So I single-stepped through the BERT model-loading code in a debugger; this post briefly describes what each part of that code does. Note: the root directory below is the root of the albert-pytorch project, from the corresponding GitHub repo.
The chain of files the model loading jumps through:

/run_classifier.py(387) AlbertForSequenceClassification.from_pretrained()
-> /model/modeling_utils.py(191) from_pretrained()
-> /model/modeling_utils.py(363) load()
-> /model/modeling_utils.py(347) load()   # a recursive function: the "prefix" argument changes on each recursive call and controls which model parameters get loaded
-> /model/modeling_utils.py(347) module._load_from_state_dict()
-> {$TORCH_HOME}/nn/modules/module.py(703) _load_from_state_dict()
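To make the role of prefix concrete, here is a minimal standalone sketch of that recursion (my own simplification, not the repo's code; the _load_from_state_dict signature is the one used in this era of PyTorch):

import torch.nn as nn

def load(module, state_dict, prefix=""):
    # Each module first consumes the keys under its own prefix...
    module._load_from_state_dict(state_dict, prefix, {}, True, [], [], [])
    # ...then recurses into its children, extending the prefix with "child_name."
    for name, child in module._modules.items():
        if child is not None:
            load(child, state_dict, prefix + name + ".")

model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2))
# state_dict keys are flat strings: "0.weight", "1.weight", "1.bias"
load(model, model.state_dict())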

The heavy lifting happens inside _load_from_state_dict(), in the for-loop spanning lines 742~769. When a model parameter loads successfully, line 762, param.copy_(input_param), is executed. The for-loop is reproduced below:
742 for name, param in local_state.items():
743     key = prefix + name
744     if key in state_dict:
745         input_param = state_dict[key]
746
747         # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
748         if len(param.shape) == 0 and len(input_param.shape) == 1:
749             input_param = input_param[0]
750
751         if input_param.shape != param.shape:
752             # local shape should match the one in checkpoint
753             error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
754                               'the shape in current model is {}.'
755                               .format(key, input_param.shape, param.shape))
756             continue
757
758         if isinstance(input_param, Parameter):
759             # backwards compatibility for serialized parameters
760             input_param = input_param.data
761         try:
762             param.copy_(input_param)
763         except Exception:
764             error_msgs.append('While copying the parameter named "{}", '
765                               'whose dimensions in the model are {} and '
766                               'whose dimensions in the checkpoint are {}.'
767                               .format(key, param.size(), input_param.size()))
768     elif strict:
769         missing_keys.append(key)
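You can trigger the size-mismatch branch (lines 751~756) yourself without any BERT checkpoint; a quick sketch:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
bad_ckpt = {"weight": torch.zeros(10, 8)}  # deliberately wrong embedding_dim
try:
    emb.load_state_dict(bad_ckpt)          # raises once error_msgs is non-empty
except RuntimeError as e:
    print(e)  # size mismatch for weight: copying a param with shape torch.Size([10, 8]) ...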

Here param is a torch.Tensor, so let's read the docstring of torch.Tensor.copy_():
Copies the elements from src into self tensor and returns self. The
src tensor must be broadcastable with the self tensor. It may be of a
different data type or reside on a different device.
Straightforward: src = input_param is copied element-wise into self, where self is the param tensor itself (here, the weight of the current nn.Module), and the method returns self, the now-updated tensor, not src.
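A tiny experiment makes those semantics concrete:

import torch

param = torch.zeros(2, 3)
src = torch.arange(6, dtype=torch.float64).reshape(2, 3)  # a different dtype is allowed
out = param.copy_(src)   # in place: param now holds src's values, cast to float32
print(out is param)      # True: copy_ returns self, not src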
Take the parameter bert.embeddings.word_embeddings.weight as an example: here param is the weight of Embedding(21128, 128, padding_idx=0), that is, the model's embedding layer. From the source of class torch.nn.Embedding (attached below) we can see that the layer has attributes num_embeddings, embedding_dim, and so on; here they are 21128 and 128 respectively (the former is the vocabulary size, the latter is ALBERT's embedding size).
class Embedding(Module):
    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
                 sparse=False, _weight=None):
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
            elif padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
            self.reset_parameters()
        else:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            self.weight = Parameter(_weight)
        self.sparse = sparse
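Constructing the same layer by hand shows these attributes directly:

import torch.nn as nn

emb = nn.Embedding(21128, 128, padding_idx=0)  # vocab size, ALBERT embedding size
print(emb.num_embeddings, emb.embedding_dim)   # 21128 128
print(emb.weight.shape)                        # torch.Size([21128, 128])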

Note the if _weight is None: branch at the end of __init__: self.weight is first allocated as a tensor of the specified shape, and then self.reset_parameters() fills it with random numbers drawn from a N(0, 1) normal distribution. In other words, before param.copy_(input_param) executes, the Embedding layer's parameters are random, not the pretrained values. (Incidentally, the values printed below are an order of magnitude smaller than typical N(0, 1) draws; most likely the model's own BERT-style weight init, a normal with std = initializer_range, re-initialized them after construction. Either way, they are random.) Let's print this self.weight first and take a look:
ipdb> self.weight
Parameter containing:
tensor([[-0.0072, -0.0040,  0.0490,  ..., -0.0219,  0.0050, -0.0293],
        [ 0.0405,  0.0166, -0.0039,  ..., -0.0099, -0.0004, -0.0137],
        [ 0.0111, -0.0048,  0.0283,  ...,  0.0047, -0.0072,  0.0130],
        ...,
        [ 0.0209, -0.0084, -0.0283,  ...,  0.0367,  0.0080, -0.0220],
        [ 0.0584,  0.0286,  0.0028,  ..., -0.0016,  0.0436,  0.0071],
        [ 0.0238, -0.0204,  0.0172,  ..., -0.0435, -0.0267,  0.0099]],
       requires_grad=True)

After param.copy_(input_param) has run, let's check whether self.weight has been overwritten with the contents of bert.embeddings.word_embeddings.weight:
ipdb> self.weight
Parameter containing:
tensor([[ 0.0722,  0.0224,  0.1045,  ...,  0.0800,  0.0776, -0.0483],
        [ 0.0779,  0.0606,  0.0891,  ...,  0.0628,  0.0831, -0.0924],
        [ 0.0891,  0.0782,  0.0731,  ...,  0.0609,  0.1201, -0.0561],
        ...,
        [ 0.0159,  0.0438,  0.1095,  ...,  0.0802,  0.0773, -0.0790],
        [ 0.0664,  0.0513,  0.1075,  ...,  0.0682,  0.0776, -0.0842],
        [ 0.0135,  0.0239,  0.1113,  ...,  0.0646,  0.0756, -0.0632]],
       requires_grad=True)

This is exactly the value of bert.embeddings.word_embeddings.weight (i.e. input_param):
ipdb> (input_param == self.weight).numpy().all()
True

This confirms that param.copy_(input_param) at line 762 is exactly what writes input_param into self.weight.
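Finally, if you want this loaded-correctly check without single-stepping a debugger, something along these lines works (a sketch under assumptions: model is the already-loaded model, pytorch_model.bin is in the current directory, and checkpoint keys may need prefix adjustments depending on the repo):

import torch

ckpt = torch.load("pytorch_model.bin", map_location="cpu")
model_state = model.state_dict()  # assumes model was built via from_pretrained()
for key, tensor in ckpt.items():
    if key in model_state:
        status = "loaded" if torch.equal(model_state[key], tensor) else "MISMATCH"
    else:
        status = "missing from model (prefix mismatch?)"
    print(key, status)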
