Attention visualization
```python
import matplotlib.pyplot as plt
import numpy as np


def samplemat(dims):
    """Make a matrix with all zeros and increasing elements on the diagonal."""
    aa = np.zeros(dims)
    for i in range(min(dims)):
        aa[i, i] = i
    return aa


# Default sentence: "Kindergartens attached to urban residential communities
# must not be run as for-profit kindergartens."
def draw(sentence="城镇小区配套幼儿园不得办成营利性幼儿园", attention_matrix=None):
    # One label per character, wrapped in the CLS/SEP special tokens
    text_labels = ["CLS"] + list(sentence) + ["SEP"]
    n = len(text_labels)
    # Fall back to a dummy diagonal matrix when no attention weights are given
    if attention_matrix is None:
        attention_matrix = samplemat((n, n))

    # Display matrix
    figure = plt.figure()
    ax = figure.add_axes([0.1, 0.1, 0.8, 0.8])
    # SimHei is a Chinese font; it must be installed for the labels to render
    font = {"family": "SimHei", "weight": "bold", "size": 8}
    ax.set_xticks(list(range(n)))  # one x tick per token
    ax.set_yticks(list(range(n)))  # one y tick per token
    ax.set_xticklabels(text_labels, **font)  # token labels on the x axis
    ax.set_yticklabels(text_labels, **font)  # token labels on the y axis
    ax.imshow(attention_matrix)  # draw the matrix as a heat map
    plt.show()


draw()
```
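The heat map above is filled by `samplemat`, which is only a placeholder diagonal matrix. A real attention map for this sentence is a square matrix with one row and one column per token (the characters plus CLS and SEP), where each row is a probability distribution. The sketch below shows how such a matrix arises from scaled dot-product self-attention, softmax(QKᵀ/√d). It assumes random, hypothetical embeddings and projection matrices (`x`, `w_q`, `w_k`, `d_model` are illustrative, not from a trained model); with a real model, the weights would instead be read from its attention layers (Hugging Face Transformers models, for example, return them when called with `output_attentions=True`).

```python
import numpy as np


def scaled_dot_product_weights(q, k):
    """Row-wise softmax(q @ k.T / sqrt(d)): one attention distribution per query token."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)
    # Subtract each row's max before exponentiating, for numerical stability
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


sentence = "城镇小区配套幼儿园不得办成营利性幼儿园"
n_tokens = len(sentence) + 2  # characters plus CLS and SEP
d_model = 16                  # hypothetical hidden size, for illustration only

rng = np.random.default_rng(0)
x = rng.normal(size=(n_tokens, d_model))    # hypothetical token embeddings
w_q = rng.normal(size=(d_model, d_model))   # hypothetical query projection
w_k = rng.normal(size=(d_model, d_model))   # hypothetical key projection

attention = scaled_dot_product_weights(x @ w_q, x @ w_k)
print(attention.shape)
```

Because `attention` has the same shape as the token grid, it can take the place of the dummy `samplemat` output when plotting; each row then shows how strongly one character attends to every other token.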