图神经网络|MPNN(message passing neural networks)


Neural Message Passing for Quantum Chemistry阅读笔记

  • 1. 摘要
  • 2. 贡献
  • 3. 信息传递网络
    • 3.1卷积网络的分子指纹学习 ( Convolutional net-works on graphs for learning molecular fingerprints)
    • 3.2 门图神经网络(Gated Graph Neural Networks ([GG-NN](https://arxiv.org/abs/1511.05493)))
    • 3.3 Interaction Networks(Interaction networks for learning about objects, relations and physics)
    • 3.4 Molecular Graph Convolutions(Molecular graph convolutions:Moving beyond fingerprints)
    • 3.5 Deep Tensor Neural Networks( Quantum chemical insights from deep tensor neural networks)
    • 3.6 Laplacian Based Methods(Semi-Supervised Classification with Graph Convolutional Networks. )
  • 4. MPNN变体
    • 4.1 信息函数(message function)
    • 4.2 虚拟图元素(很不错的idea)
    • 4.3 readout function
    • Multiply Towers

1. 摘要 论文将当前存在的模型整合起来合成MPNNs,在框架内搜新的变量。模型要将输入的分子图通过MPNNs 和信息聚合算出一个函数,这是一个监督学习模型。
2. 贡献
  • MPNN在13个靶向物中达到很好的效果,而且在13个靶向物中的11个进行了预测DFT准确性。
  • 发展了几个不同的MPNNs模型来预测13个靶向物中的5个化合物DFT( Density Functional Theory)的准确性,
  • 我们发展了一个通用算法来训练大结点表示的MPNNs, 比起以前的大节点表示节省了大量的计算时间和资源。
3. 信息传递网络 【图神经网络|MPNN(message passing neural networks)】通用格式图神经网络|MPNN(message passing neural networks)
文章图片

  • N(v)表示顶点v的领域
  • M t M_t Mt?是信息传递函数
  • U t U_t Ut?是顶点更新函数
最后整个图的结点通过readout层函数R,整合成一个representationy ^ \hat y y^?,
图神经网络|MPNN(message passing neural networks)
文章图片

  • R 可微函数,且计算结果不随 h v h_v hv?的排列而变化(目的是保证图同构时保证结果的不变)
3.1卷积网络的分子指纹学习 ( Convolutional net-works on graphs for learning molecular fingerprints)
  • 信息传递函数用的是: M ( h v , h w , e v w ) = ( h w , e v w ) M(h_v,h_w,e_{vw})=(h_w,e_{vw}) M(hv?,hw?,evw?)=(hw?,evw?),这里括号是concatenation
  • 顶点更新函数 U t ( h v , m v t + 1 ) = σ ( H t d e g ( v ) m v t + 1 ) U_t(h_v,m_v^{t+1})=\sigma(H_t^{deg(v)}m_v^{t+1}) Ut?(hv?,mvt+1?)=σ(Htdeg(v)?mvt+1?),deg(v)是顶点度, H t N H_t^N HtN?则表示在 t 步且顶点度数为N下的可学习矩阵。
  • readout函数R是一个跳跃连接,将之前的隐含层 h v t h_v^t hvt?连接起来,即
    R = f ( ∑ v , t s o l f m a x ( W t h v t ) ) R=f(\sum_{v,t}solfmax(W_th^t_v)) R=f(∑v,t?solfmax(Wt?hvt?)), 这里f 是神经网络, W t W_t Wt?是针对第t步的一个可以学习的readout矩阵。
    将上面的传递函数代进通用模型,会发现 m v t + 1 = ( ∑ N ( v ) h w t , ∑ N ( v ) e v w ) m_v^{t+1}=(\sum_{N(v)}h_w^t,\sum_{N(v)}e_{vw}) mvt+1?=(∑N(v)?hwt?,∑N(v)?evw?),这样导致了分别对连通节点和连通边求和,没有合成好信息,所以这个模型最后被证实没有能力去确定边状态和结点状态的相互关系。
3.2 门图神经网络(Gated Graph Neural Networks (GG-NN))
  • M ( h v , h w , e v w ) = A e v w h w t M(h_v,h_w,e_{vw})=A_{e_{vw}}h_w^t M(hv?,hw?,evw?)=Aevw??hwt?, A e v w A_{e_{vw}} Aevw??是可学习矩阵,每类边对应一种矩阵(假定边标签为离散的边类型),
  • U t = G R U ( h v t , m v t + 1 ) U_t=GRU(h_v^t,m_v^{t+1}) Ut?=GRU(hvt?,mvt+1?),这里GRU是门循环单元,
  • R = ∑ v ∈ V σ ( i ( h v ( T ) , h v 0 ) ) ⊙ ( j ( h v ( T ) ) ) R=\sum_{v\in V}\sigma(i(h_v^{(T)},h_v^0))\odot(j(h_v^{(T)})) R=∑v∈V?σ(i(hv(T)?,hv0?))⊙(j(hv(T)?)) i 与 j 是神经网络。 ⊙ \odot ⊙表示元素乘积
3.3 Interaction Networks(Interaction networks for learning about objects, relations and physics)
  • M ( h v , h w , e v w ) M(h_v,h_w,e_{vw}) M(hv?,hw?,evw?)是一个神经网络,将concatenation ( h v , h w , e v w ) (h_v,h_w,e_{vw}) (hv?,hw?,evw?)作为输入
  • U ( h v , x v , m v ) U(h_v,x_v,m_v) U(hv?,xv?,mv?)也是一个神经网络将concatenation ( h v , x v , m v ) (h_v,x_v,m_v) (hv?,xv?,mv?)作为输入,
  • R = f ( ∑ v ∈ G h v T ) R=f(\sum_{v\in G}h_v^T) R=f(∑v∈G?hvT?),f是神经网络,将所有最后一步的隐含状态的和作为输入
3.4 Molecular Graph Convolutions(Molecular graph convolutions:Moving beyond fingerprints)
  • M ( h v t , h w t , e v w t ) = e v w t M(h_v^t,h_w^t,e_{vw}^t)=e^t_{vw} M(hvt?,hwt?,evwt?)=evwt?
  • U t ( h v t , m v t + 1 ) = α ( W 1 ( α ( W 0 h v t ) , m v t + 1 ) ) U_t(h_v^t,m_v^{t+1})=\alpha(W_1(\alpha(W_0h_v^t),m_v^{t+1})) Ut?(hvt?,mvt+1?)=α(W1?(α(W0?hvt?),mvt+1?))括号表示concatenation, α \alpha α表示ReLU, W 1 , W 2 W_1,W_2 W1?,W2?表示可学习的权重矩阵。
  • e v w t + 1 = U t ′ ( e v w t , h v t , h w t ) = α ( W 4 ( α ( W 2 , e v w t ) , α ( W 3 ( h v t , h w t ) ) ) ) e_{vw}^{t+1}=U_t'(e_{vw}^t,h_v^t,h_w^t)=\alpha(W_4(\alpha(W_2,e_{vw}^t),\alpha(W_3(h_v^t,h_w^t)))) evwt+1?=Ut′?(evwt?,hvt?,hwt?)=α(W4?(α(W2?,evwt?),α(W3?(hvt?,hwt?)))), W i W_i Wi?是可学习的权重矩阵。
3.5 Deep Tensor Neural Networks( Quantum chemical insights from deep tensor neural networks)
  • M t = t a n h ( W f c ( ( W c f h w t + b 1 ) ⊙ ( W d f e v w + b 2 ) ) ) M_t=tanh(W^{fc}((W^{cf}h_w^t+b_1)\odot(W^{df}e_{vw}+b_2))) Mt?=tanh(Wfc((Wcfhwt?+b1?)⊙(Wdfevw?+b2?))) W 是矩阵,b是偏置向量
  • U t ( h v t , m v t + 1 ) = h v t + m v t + 1 U_t(h_v^t,m_v^{t+1})=h_v^t+m_v^{t+1} Ut?(hvt?,mvt+1?)=hvt?+mvt+1?
  • R = ∑ v N N ( h v t ) R=\sum_vNN(h_v^t) R=∑v?NN(hvt?), NN 是单层神经网络,最后将每个点的输出加起来
3.6 Laplacian Based Methods(Semi-Supervised Classification with Graph Convolutional Networks. )
  1. Convolutional neural networks on graphs with
    fast localized spectral filtering 提出的
  • M t ( h v t , h w t ) = C v w t h w t M_t(h_v^t,h_w^t)=C_{vw}^th_w^t Mt?(hvt?,hwt?)=Cvwt?hwt?, C是由graph 的Laplacian 矩阵L的特征向量构成的可学习的参数矩阵。
  • U t ( h v t , m v t + 1 ) = σ ( m v t + 1 ) U_t(h_v^t,m_v^{t+1})=\sigma(m_v^{t+1}) Ut?(hvt?,mvt+1?)=σ(mvt+1?)
  1. Semi-Supervised Classification with Graph Convolutional Networks
  • M t ( h v t , h w t ) = c v w h w t = [ ( d e g ( v ) d e g ( w ) ) ? 1 / 2 A v w ] h w t M_t(h_v^t,h_w^t)=c_{vw}h_w^t=[(deg(v)deg(w))^{-1/2}A_{vw}]h_w^t Mt?(hvt?,hwt?)=cvw?hwt?=[(deg(v)deg(w))?1/2Avw?]hwt?,A是实连接矩阵
  • U v t ( h v t , m v t + 1 ) = R e L U ( W t m v t + 1 ) U_v^t(h_v^t,m_v^{t+1})=ReLU(W^tm_v^{t+1}) Uvt?(hvt?,mvt+1?)=ReLU(Wtmvt+1?)
4. MPNN变体 我们用GG-NNmodel作为baseLine
  • d 表示图中每个节点隐含表示的维度
  • n表示图节点的数量
  • 为了应用MPNNs到有向图,我们将入边和出边分辨用单独的message channel去处理,最后模型的信息传递函数就可以写成 m v i n m_v^{in} mvin?与 m v o u t m_v^{out} mvout?的concatenation
  • 当无向图当成有向图时,我们需要将入边和出边度标注相同的标签。且将message channel由 d d d改成 2 d 2d 2d
  • 模型输入:图节点的特征向量集 x v x_v xv?、一个具有向量值项的邻接矩阵A,表示分子中不同的键以及两个原子间的成对空间距离。使用GG-NN家族的message Function
4.1 信息函数(message function)
  • 矩阵乘法 : M ( h v , h w , e v w ) = A e v w h w M(h_v,h_w,e_{vw})=A_{e_{vw}}h_w M(hv?,hw?,evw?)=Aevw??hw?
  • 边网络:边信息传递函数: M ( h v , h w , e v w ) = A ( e v w ) h w M(h_v,h_w,e_{vw})=A(e_{vw})h_w M(hv?,hw?,evw?)=A(evw?)hw?, 这里A是神经网络,将边向量 e v w e_{vw} evw?映射到 d × d d\times d d×d矩阵
  • 成对信息: 理论上来说,如果节点信息传递函数计算的时候同时利用出边节点和入边节点的信息将会更加的高效(比起只利用出边结点和边信息)。所以这里采用了利用三者的信息:m w v = f ( h w t , h v t , e v w ) m_{wv}=f(h_w^t,h_v^t,e_{vw}) mwv?=f(hwt?,hvt?,evw?), 这里 f 是一个神经网络
  • 有向图:将上面的信息函数运用到有向图时,我们要用两个独立的message函数, M i n , M o u t M^{in},M^{out} Min,Mout, e v w e_{vw} evw?决定于边的方向。
4.2 虚拟图元素(很不错的idea)
  1. 添加虚拟边,将距离比较近的结点认为加上边,使其在传播时可以访问更长的距离
  2. 添加“主”结点
    即认为主节点与所有输入节点都有特殊的边。在消息传递的每个步骤中,主节点充当一个全局起始空间,每个节点对它进行读写操作。可以让组节点有其特定的维度 d m a s t e r d_{master} dmaster?,并且在内部更新函数中有其独立的参数。
4.3 readout function 这里尝试了两种readout function。
  • 第一种是用的GG-NN中的readout函数,R = ∑ v ∈ V σ ( i ( h v ( T ) , h v 0 ) ) ⊙ ( j ( h v ( T ) ) ) R=\sum_{v\in V}\sigma(i(h_v^{(T)},h_v^0))\odot(j(h_v^{(T)})) R=∑v∈V?σ(i(hv(T)?,hv0?))⊙(j(hv(T)?))
  • 第二种用的set2set( Sequence to sequence for sets.)中的readout函数,先用线性变换将每个元组 ( h v T , x v ) (h_v^T,x_v) (hvT?,xv?),然后合成为集合 T = { ( h v T , x v ) } T=\{(h_v^T,x_v)\} T={(hvT?,xv?)},然后经过M步的计算,产生了图的embedingq t ? q^*_t qt??,然后就把其喂到神经网络中产生输出。
    这个是set2set的公式:
    图神经网络|MPNN(message passing neural networks)
    文章图片
Multiply Towers MPNNs的一个重要特点是其的可扩展性,但是对于dense graph,单步信息传递过程需要 O ( n 2 d 2 ) O(n^2d^2) O(n2d2),所以为了减少计算量,将 h v t h_v^t hvt?除于k 变成d/k维向量 { h ~ v t + 1 , v ∈ G } \{\tilde{h}_v^{t+1},v\in G\} {h~vt+1?,v∈G},最后k 个顶点的临时嵌入被混合在一起
图神经网络|MPNN(message passing neural networks)
文章图片
g 是神经网络, ( x , y , . . . ) (x,y,...) (x,y,...)表示concatenation。g 共享图中所有的结点,这种混合可以保证节点排列的不变性。

    推荐阅读