从来好事天生俭,自古瓜儿苦后甜。这篇文章主要讲述代码实现 加性注意力 | additive attention #51CTO博主之星评选#相关的知识,希望能为你提供帮助。
import math
import torch
from torch import nn
from d2l import torch as d2l
python人必懂的导包,这不用解释了。
def masked_softmax(X, valid_lens):
if valid_lens is None:
return nn.functional.softmax(X, dim=-1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
value=https://www.songbingjia.com/android/-1e6)
return nn.functional.softmax(X.reshape(shape), dim=-1)
一个遮蔽softmax的操作。在nadaraya-waston核回归代码实现中我们做过一个类似的mask操作。就是倒数第三段代码那个位置,每个$x$和除自己本身以外的其他$x_i$进行计算,然后我们使用
X_tile[(1 - torch.eye(n_train)).type(torch.bool)]
将其本身遮盖掉了。也就是mask操作。这个函数的功能是这样的:就是我们传入的一整个张量可能只有一部分是有用的,所以将没用的部分mask掉,只对剩下的部分进行softmax计算。比如我们传入一个长度为5的向量,我们仅需要前两个数据,那经过这个函数之后,后三个数加起来是0,前两个数加起来是1。
- 函数两个参数
X
和valid_lens
,x是要softmax的张量,valid_lens存储每个维度上的有效长度,不管传入一维还是二维,都要确保能进行广播机制。 - 函数一进来是一个if语句
if valid_lens is None
是说如果没有给出valid_lens
,也就是整个张量都是有效的,不需要进行mask之后再softmax,所以if语句直接返回一个普通的softmax操作,函数运行结束。 - 当传入
valid_lens
的时候进入else
- 首先是用
shape
存储待mask的张量X
的shape。 - 又是一个if-else语句,这个是用来处理
valid_lens
长度的,将valid_lens
长度转化矩阵的行数。
- 当
valid_lens
是一维的时候进入if,将其转换为一个mask向量。解释一下,因为mini-batch的存在,所以传入的X
一般是三维的,第一个维度是batch size,二三维度上的才是矩阵的大小。之前用shape
存储X
的shape,现在用shape[1]
取到X
中的矩阵是几行,然后每行的有效元素对应valid_lens
中的数值。
想了解torch.repeat_interleave
看这里→pytorch中的repeat操作对比
- 当
valid_lens
不是一维的时候进入else中。直接将其从一个矩阵转化为一个向量即可。 - 对于mask操作是直接用d2l中的函数实现的,源码我就不去扒了,对于维度的处理记住:
- 如果传入的
valid_lens
是一维的,那valid_lens
的长度要和X
的第二维(shape[1]
)一样。 - 如果传入的
valid_lens
是二维的,那valid_lens
的第一维度要和batch size一样,第二维度要和X
中矩阵的行数一样。 - 具体例子可以看代码实现 缩放点积注意力 | scaled dot-product attention
- 如果传入的
- 当
- 首先是用
class AdditiveAttention(nn.Module):
def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
super(AdditiveAttention, self).__init__(**kwargs)
self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
self.w_v = nn.Linear(num_hiddens, 1, bias=False)
self.dropout = nn.Dropout(dropout)def forward(self, queries, keys, values, valid_lens):
queries, keys = self.W_q(queries), self.W_k(keys)
features = queries.unsqueeze(2) + keys.unsqueeze(1)
features = torch.tanh(features)
scores = self.w_v(features).squeeze(-1)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)
加性注意力代码部分:
因为这里涉及到一个升到四维张量,所以一定要自己捋一捋。
- 主要的三个参数,
key_size
keys的长度,query_size
query的长度,num_hiddens
隐藏层的大小。因为加性注意力是处理keys和queries长度不一样的情况。 - 三个小的线性层。
self.W_k
和self.W_q
是把key和query转化到隐藏层,self.W_v
是从隐藏层到单个输出。 - 在这里均设置不需要bias
- 最后还做了一下dropout
- 然后是前向传播函数,是计算$a(\\mathbf q, \\mathbf k) = \\mathbf w_v^\\top \\texttanh(\\mathbf W_q\\mathbf q + \\mathbf W_k \\mathbf k)$的过程:
- 将queries和keys扔进前边两个线性层就可以得到queries和keys,进行维度调整。
queries
的形状:(batch_size
, 查询的个数, 1,num_hidden
)
key
的形状:(batch_size
, 1, “键-值”对的个数,num_hiddens
)
- 进行公式的计算。
scores
的计算是self.w_v
仅有一个输出,因此从形状中移除最后那个维度。scores
的形状:(batch_size
, 查询的个数, “键-值”对的个数)- 最后
values
的形状:(batch_size
, “键-值”对的个数, 值的维度)
- 将queries和keys扔进前边两个线性层就可以得到queries和keys,进行维度调整。
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# `values` 的小批量数据集中,两个值矩阵是相同的
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
2, 1, 1)
valid_lens = torch.tensor([2, 6])attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
attention(queries, keys, values, valid_lens)
带入一个样例测试一下子。
注意这里使用到
.eval()
,是不启用 BatchNormalization 和 Dropout。d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel=Keys, ylabel=Queries)
因为和代码实现 缩放点积注意力 | scaled dot-product attention用的数据都一样的,所以就不具体解析这个热图了,不懂的可以看点积缩放注意力那篇文章的热图分析。
【代码实现 加性注意力 | additive attention #51CTO博主之星评选#】
文章图片
推荐阅读
- MySql数据库增删改查常用语句命令
- #yyds干货盘点#安装悟空CRM
- Docker基本管理
- Java桥接方法
- prometheus基于文件服务发现
- Ansible配置文件命令及模块
- Docker Harbor私有仓库部署于管理
- Docker 数据管理
- nginx一个端口配置多域名服务