# Attention visualization (category: attention)

import matplotlib.pyplot as plt
import numpy as np
def samplemat(dims):
    """Return a matrix of zeros whose diagonal holds 0, 1, 2, ...

    Handy as a synthetic stand-in pattern when no real matrix is available.

    Args:
        dims: Shape of the matrix to build, e.g. ``(rows, cols)``.

    Returns:
        A float ``numpy.ndarray`` of shape ``dims`` with ``aa[i, i] == i`` for
        every ``i < min(dims)`` and 0 everywhere else.
    """
    aa = np.zeros(dims)
    # Only the first min(dims) diagonal cells exist for a non-square shape.
    for i in range(min(dims)):
        aa[i, i] = i
    return aa
def draw(sentence="城镇小区配套幼儿园不得办成营利性幼儿园", attention_matrix=None):
    """Show a token-by-token attention heat map for ``sentence``.

    Each character of ``sentence`` is one token. The tick labels on both axes
    are ``"CLS"``, the characters, then ``"SEP"``, so the displayed matrix is
    square with side ``len(sentence) + 2``.

    Args:
        sentence: Text whose characters label both axes.
        attention_matrix: Optional array-like of shape
            ``(len(sentence) + 2, len(sentence) + 2)`` to display. When
            ``None``, a synthetic diagonal pattern from ``samplemat`` is shown
            instead.

    Raises:
        ValueError: If ``attention_matrix`` is given with the wrong shape.
    """
    text_labels = list(sentence)
    tick_labels = ["CLS"] + text_labels + ["SEP"]
    size = len(tick_labels)

    if attention_matrix is None:
        matrix = samplemat((size, size))
    else:
        # Previously this argument was ignored; now it is actually plotted.
        matrix = np.asarray(attention_matrix)
        if matrix.shape != (size, size):
            raise ValueError(
                f"attention_matrix must have shape {(size, size)}, "
                f"got {matrix.shape}"
            )

    figure = plt.figure()
    ax = figure.add_axes([0.1, 0.1, 0.8, 0.8])
    # SimHei is a CJK font, presumably chosen so the Chinese characters render;
    # it must be installed locally. Font size is given as a number (points).
    font = {"family": "SimHei", "weight": "bold", "size": 8}
    ticks = list(range(size))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(tick_labels, **font)
    ax.set_yticklabels(tick_labels, **font)
    ax.imshow(X=matrix)
    plt.show()
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    draw()

# Recommended reading (推荐阅读) — footer text left over from the source web page.