代码实现 加性注意力 | additive attention #51CTO博主之星评选#

从来好事天生俭,自古瓜儿苦后甜。这篇文章主要讲述代码实现 加性注意力 | additive attention #51CTO博主之星评选#相关的知识,希望能为你提供帮助。

import math import torch from torch import nn from d2l import torch as d2l

python人必懂的导包,这不用解释了。
def masked_softmax(X, valid_lens): if valid_lens is None: return nn.functional.softmax(X, dim=-1) else: shape = X.shape if valid_lens.dim() == 1: valid_lens = torch.repeat_interleave(valid_lens, shape[1]) else: valid_lens = valid_lens.reshape(-1) X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=https://www.songbingjia.com/android/-1e6) return nn.functional.softmax(X.reshape(shape), dim=-1)

一个遮蔽softmax的操作。在nadaraya-waston核回归代码实现中我们做过一个类似的mask操作。就是倒数第三段代码那个位置,每个$x$和除自己本身以外的其他$x_i$进行计算,然后我们使用X_tile[(1 - torch.eye(n_train)).type(torch.bool)]将其本身遮盖掉了。也就是mask操作。
这个函数的功能是这样的:就是我们传入的一整个张量可能只有一部分是有用的,所以将没用的部分mask掉,只对剩下的部分进行softmax计算。比如我们传入一个长度为5的向量,我们仅需要前两个数据,那经过这个函数之后,后三个数加起来是0,前两个数加起来是1。
  • 函数两个参数Xvalid_lens,x是要softmax的张量,valid_lens存储每个维度上的有效长度,不管传入一维还是二维,都要确保能进行广播机制。
  • 函数一进来是一个if语句if valid_lens is None是说如果没有给出valid_lens,也就是整个张量都是有效的,不需要进行mask之后再softmax,所以if语句直接返回一个普通的softmax操作,函数运行结束。
  • 当传入valid_lens的时候进入else
    • 首先是用shape存储待mask的张量X的shape。
    • 又是一个if-else语句,这个是用来处理valid_lens长度的,将valid_lens长度转化矩阵的行数。
      • valid_lens是一维的时候进入if,将其转换为一个mask向量。解释一下,因为mini-batch的存在,所以传入的X一般是三维的,第一个维度是batch size,二三维度上的才是矩阵的大小。之前用shape存储X的shape,现在用shape[1]取到X中的矩阵是几行,然后每行的有效元素对应valid_lens中的数值。
        想了解torch.repeat_interleave看这里→pytorch中的repeat操作对比
      • valid_lens不是一维的时候进入else中。直接将其从一个矩阵转化为一个向量即可。
      • 对于mask操作是直接用d2l中的函数实现的,源码我就不去扒了,对于维度的处理记住:
        • 如果传入的valid_lens是一维的,那valid_lens的长度要和X的第二维(shape[1])一样。
        • 如果传入的valid_lens是二维的,那valid_lens的第一维度要和batch size一样,第二维度要和X中矩阵的行数一样。
        • 具体例子可以看代码实现 缩放点积注意力 | scaled dot-product attention
class AdditiveAttention(nn.Module): def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs): super(AdditiveAttention, self).__init__(**kwargs) self.W_k = nn.Linear(key_size, num_hiddens, bias=False) self.W_q = nn.Linear(query_size, num_hiddens, bias=False) self.w_v = nn.Linear(num_hiddens, 1, bias=False) self.dropout = nn.Dropout(dropout)def forward(self, queries, keys, values, valid_lens): queries, keys = self.W_q(queries), self.W_k(keys) features = queries.unsqueeze(2) + keys.unsqueeze(1) features = torch.tanh(features) scores = self.w_v(features).squeeze(-1) self.attention_weights = masked_softmax(scores, valid_lens) return torch.bmm(self.dropout(self.attention_weights), values)

加性注意力代码部分:
因为这里涉及到一个升到四维张量,所以一定要自己捋一捋。
  • 主要的三个参数,key_sizekeys的长度, query_sizequery的长度, num_hiddens隐藏层的大小。因为加性注意力是处理keys和queries长度不一样的情况。
  • 三个小的线性层。self.W_kself.W_q是把key和query转化到隐藏层,self.W_v是从隐藏层到单个输出。
  • 在这里均设置不需要bias
  • 最后还做了一下dropout
  • 然后是前向传播函数,是计算$a(\\mathbf q, \\mathbf k) = \\mathbf w_v^\\top \\texttanh(\\mathbf W_q\\mathbf q + \\mathbf W_k \\mathbf k)$的过程:
    • 将queries和keys扔进前边两个线性层就可以得到queries和keys,进行维度调整。
      queries 的形状:(batch_size, 查询的个数, 1, num_hidden)
      key 的形状:(batch_size, 1, “键-值”对的个数, num_hiddens)
    • 进行公式的计算。
    • scores的计算是self.w_v 仅有一个输出,因此从形状中移除最后那个维度。
      scores 的形状:(batch_size, 查询的个数, “键-值”对的个数)
    • 最后values 的形状:(batch_size, “键-值”对的个数, 值的维度)
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2)) # `values` 的小批量数据集中,两个值矩阵是相同的 values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat( 2, 1, 1) valid_lens = torch.tensor([2, 6])attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1) attention.eval() attention(queries, keys, values, valid_lens)

带入一个样例测试一下子。
注意这里使用到.eval(),是不启用 BatchNormalization 和 Dropout。
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)), xlabel=Keys, ylabel=Queries)

因为和代码实现 缩放点积注意力 | scaled dot-product attention用的数据都一样的,所以就不具体解析这个热图了,不懂的可以看点积缩放注意力那篇文章的热图分析。
【代码实现 加性注意力 | additive attention #51CTO博主之星评选#】
代码实现 加性注意力 | additive attention #51CTO博主之星评选#

文章图片


    推荐阅读