登山则情满于山,观海则意溢于海。这篇文章主要讲述比较图神经网络PyTorch Geometric 与 Deep Graph Library,帮助团队选出适合的GNN库相关的知识,希望能为你提供帮助。
文章图片
PART 01
开篇
本文比较了Deep Graph Library (DGL) 和 PyTorch Geometric 这两个图神经网络,以帮助你选择适合团队的GNN库。
PART 02
图神经网络比较
DGL 与 PyTorch Geometric
什么是基于图的深度学习?一般来说,图是由边和节点连接形成的系统,而节点则具有某种内部状态,通过连接节点的边所定义的当前节点与其他节点的关系来修改,同时这些连接和节点的状态还可以以多种方式定义。
深度学习是对数据重复进行非线性变换,通常的做法是矩阵乘法或卷积。将深度学习和图相结合也就有了快速发展的图神经网络 (GNN) 领域。图为任何由节点和关系定义的系统提供了一个有用的框架,包括社交网络、分子和许多其他类型的理论系统。处理定义为图的数据可为所研究的系统提供有意义的结构数据,以及大量可用的数学和算法工具。更重要的是,由于图是可以被矩阵数学进行描述和操作,图也成为了深度学习领域的重要补充,并从多年来主要使用相同数学原语的快速深度学习库的发展中大大受益。
文章图片
邻接矩阵表示图中的边连接。rivesunder的公共领域图?
图为许多类型的问题提供了丰富的框架,也使深度学习神经网络在过去几十年中取得的巨大成功。毫无疑问,图神经网络(GNN)越来越受到关注,并实现了自身突破。到目前为止,图深度学习最令人兴奋的成就是 DeepMind 开发的 AlphaFold 和 AlphaFold2,该项目在解决结构生物学长期以来的蛋白质结构预测问题方面取得了重大进展。由于在药物发现、社交网络、基础生物学和许多其他领域中也有无数重要应用,人们已经开发了许多用于处理图神经网络的开源库。其中许多开源库已经足够成熟,并且可以在生产或研究中使用,而要在开始一个新项目时选择合适的库去使用,要考虑的因素很多。最重要的GNN库选择依据是与团队现有专业知识的兼容性:如果你熟悉PyTorch的使用,那么PyTorch Geometric是不错的选择,尽管你可能也使用带有 PyTorch的DGL作为后端(DGL 也可以使用 TensorFlow 作为后端)。同样,如果你更熟悉 TensorFlow 和 Keras,Spektral可能会更有价值。如果你想使用新兴的 JAX 生态系统进行开发,那么Jraph可能非常适合你的 GNN 项目。当然,如果你的团队更喜欢 Julia 而不是 python,你可能更希望着眼于GeometricFlux.jl或GraphNeuralNetworks.jl,它们都基于 Flux.jl 机器学习生态系统。使用 Julia 编写的其他工具和 Julia 编程语言本身一样,GeometricFlux.jl 和 GraphNeuralNetworks.jl 并不像更成熟的 Python 同类工具那样知名,并且其对应的社区更加小众,尽管它们也确实有一些引人注目的优势。基于 Julia 的工具的优势之一是执行速度,这要归功于 Julia 内置的“即时”编译。虽然当下提供的语言和类似库(如PyTorch或TensorFlow)在完成给定任务方面显著提高了开发效率,但对机器学习项目效率的另一个主要贡献就是计算速度。与代码本身的执行速度相比,开发人员时间往往是一种稀缺资源,在机器学习项目中往往被低估。但想象一下,如果没有广泛可用的库运用到GPU上实现有效的硬件加速,深度学习恐怕需要很长时间才能被成熟应用。然而,如果没有高效软件的配合,世界上所有的专用硬件都不能发挥最大效用。本文将对关注度最高的两个开放源码库进行基准测试和比较,比较的范围设定为图形神经网络的计算。为了进行比较,将重点介绍Python库 PyTorch Gemetric 和 Deep Graph Library(深度图库 DGL)。顾名思义,Pytork Geometric是基于Pytorh的(加上许多用于处理稀疏矩阵的Pytorh扩展),而DGL可以使用Pytorh或TensorFlow作为其后端。DGL 被用于开发 SE3-Transformer,它是一种平移和旋转不变模型,该模型对蛋白质结构预测冠军模型 - AlphaFold有很大影响。Baker 实验室受 DeepMind 工作启发,将DGL用于开源RosettaFold的蛋白质结构预测。PyTorch Geometric是一个相当流行的库,它在GitHub 上有 13,000 多颗星,为所有有 PyTorch 经验的人提供了方便熟悉的 API。我们将介绍每个API,并对Zitnik 和 Leskovec 2017 年论文中蛋白质与蛋白质相互作用 (protein-protein interaction, PPI) 数据集中的等效 GNN 架构进行基准测试。PPI 数据集呈现了一个多类节点分类任务,每个节点代表一个由 50 个特征组成的蛋白质,并用 121 个非排他标签进行标记。如果你使用深度学习模型已有一段时间,可能也见证了 Python 库的兴衰。在谷歌2015 年发布开源库TensorFlow后,TensorFlow得到了广泛采用。此前,深度学习库的布局是由 Theano、AutoGrad、Caffe、Lasagne、Keras、Chainer 等框架组成的多元事务。在此期间,深度学习的库都是自主研发的。如果想要 GPU的 支持,就必须了解 CUDA。PyTorch 于 2016 年推出,虽然发展缓慢,但无疑将成为了深度学习的首选库,同时TensorFlow 吞并了 Keras,依旧受到生产流程的青睐。到 TensorFlow 发布 2.0 版时,似乎就成了 PyTorch 和 TensorFlow “两个库”之间的游戏。它们之间的差异越来越小,TensorFlow 变得像 PyTorch 一样更加动态,PyTorch 通过即时编译和Torchscript变得更快。也许是因为两大类库逐渐趋同,学术项目 Autograd 后续的 JAX 为具备能力、功能性和可组合性为一身的深度学习找到了一个开放的定位,DeepMind 等主要实验室也正保持对 JAX 的关注。Jraph 是DeepMind 对基于图深度学习的基于 JAX 的解决方案,尽管该方案与TensorFlow项目Graph Nets有诸多特征相似之处(在撰写本文时,Graph Nets项目一年多没有更新了)。在下一节中,我们将了解如何安装和设置 DGL 和 PyTorch Geometric,以及如何使用每个库构建具有 6 个隐藏层的图卷积网络。还将在 PPI 数据集上的节点分类建立一个训练循环,并讨论每个人使用的图形数据结构 API 的差异。最后,我们在单个 NVIDIA GPU 上执行 10,000 个 epoch 训练并比较每个 GPU 的速度。
PART 03
PyTorch Geometric
PyTorch Geometric (PyG) 是一个直观的库,很像标准的 PyTorch。数据集和数据加载器具有一致的 API,因此无需针对不同任务手动调整模型架构。
1、安装请注意,我们基于 pip、PyTorch 1.10、Python 3.6 和 CUDA 10.2 的系统设置安装了 PyTorch Geometric。
文章图片
2、模型和代码我们将使用基准测试库的架构,该架构是基于 Kipf 和 Welling 在其 2016 年论文中描述的图卷积层(PyTorch Geometric 中的 GCNConv 和 DGL 中的 GraphConv)。PyTorch Geometric 的图形层使用的 API 与 PyTorch 的非常相似,但将使用 PyTorch Geometric Batch 类的 edge_index 中的图形边作为输入。库中的批次将一个或多个图的聚集描述为一个有内部间隙的大图。对于图卷积,这些批次使用矩阵乘法和组合的邻接矩阵来实现权重共享,但 Batch 对象还在一个称为批次的变量中跟踪节点与图的对照关系。我们将使用的图卷积模型如下图所示:
文章图片
用于对 DGL 和 PyTorch Geometric 进行基准测试的图卷积网络图在代码中,我们的模型是通过继承 PyTorch 的 torch.nn.Module 模型类构建的。
文章图片
请注意,这个模型没有将张量作为输入,而是采用一个名为“batch”的变量,有时在常见的样式约定中称为“data”。该批次包含定义节点与图的对应关系以及这些节点如何连接的额外信息。除了这种差异之外,该模型读起来很像标准卷积网络,使用 GCNConv 或类似基于图的层则不是标准卷积。PyTorch 用户也非常熟悉训练循环,但它是将整个批次传递给模型,而不是单独的输入张量。但在进入训练循环之前,我们需要下载蛋白质与蛋白质相互作用 (PPI) 数据集并设置训练和测试数据加载器。
文章图片
现在我们准备好定义训练循环了,此外还跟踪每个epoch的时间和损失。
文章图片
如此, 该代码已准备好在 PPI 数据集上对 PyTorch Geometric 进行基准测试。有了这些,在 Deep Graph Library 中构建等效模型会容易得多,这与我们下一节中讨论的代码会有一些差异。
PART 04
深度图库
Deep Graph Library, DGL
Deep Graph Library 是一个灵活的库,可以利用 PyTorch 或 TensorFlow 作为后端。我们将使用 PyTorch 进行此演示,但如果你常用 TensorFlow 并希望将其用于图形的深度学习,你可以通过将“TensorFlow”导出到名为 DGLBACKEND 的环境变量来实现。至少,可以将代码调整为来自 tf.keras.Model 而不是 torch.nn.Module 的子类,并使用来自 keras 中API 的fit 方法。
1、安装同样,在使用 CUDA 10.2 的系统上进行安装,但如果你已升级到 CUDA 11 或仍停留在 CUDA 10.1 上,你可以从DGL 网站获取正确的 pip install 命令。
文章图片
2、模型和代码在 DGL 中,Kipf 和 Welling 图卷积层称为“GraphConv”,而不是 PyTorch Geometric 中使用的“GCNConv”。除此之外,该模型看起来基本相同。
文章图片
请注意,我们不是传递批处理作为输入,而是传递 g(DGL 图对象)和节点特征。在设置训练循环时,也大致相同,但要特别注意传递给模型的内容。在这种情况下,节点特征于 batch.ndata["feat"] 中找到,但我们发现用于节点特征的特定键因数据集而异。feat 可能是最常见的,但你也会发现 node_attr 和其他内容,不一致的 API 可能有点令人困惑。这确实是一个痛点,因为我们为此演示尝试了不同的内置数据集,并且重写代码位以适应不同的数据集确实会减慢开发速度。我们自然更喜欢PyTorch Geometric 中使用的“批处理”(Batch) 对象的一致性。在实践中,一致的内部样式在实际应用中并不会造成问题,因为无论如何你都不会使用内置数据集。
文章图片
文章图片
3、结果
文章图片
PPI 数据集上 DGL 和 PyTorch 几何图卷积网络的训练曲线?
我们在 PPI 数据集上使用 PyTorch Geometric 和 DGL 完成了 10,000 个 epoch 的训练,在单个 NVIDIA GTX 1060 GPU 上运行。PyG 用了 2,984.34 秒完成训练,而 DGL 用时不到一半,为 1,148.05 秒。两次运行都以相似的性能结束,PyG 的测试准确度为 73.35%,DGL 的测试准确度为 77.38%,我们期望通过随机初始化的方式让每次运行发生一些偶然情况。在 10,000 个 epochs 之后,损失仍在减少,因此你可以预期此模型架构最终会收敛到稍高的准确度(尽管我们没有跟踪此实验的验证损失)。
PART 05
哪个 GNN 库最适合自己
不同库训练时间的巨大差异让人们惊讶。鉴于 DGL 和 PyG 都是基于 PyTorch 构建的,或者使用 PyTorch 作为计算后端,预计它们都会在10% 或 20%内完成。
【比较图神经网络PyTorch Geometric 与 Deep Graph Library,帮助团队选出适合的GNN库】因为批处理 API 比 DGL 中的等效 API 更加直观和一致,我们确实发现使用 PyTorch Geometric 设置训练循环比使用 DGL 更舒服一些。在确定 PPI 数据集之前,我们尝试了几个不同的数据集,似乎每个数据集都使用不同的键来使DGL 检索节点特征。话虽如此,在使用DGL中遇到的一些小麻烦可能与库的熟悉程度有关,毕竟我们使用 PyG 的时间比 DGL 要长。不同架构和层类型的性能确实值得研究,但仅靠选择正确的库很难将性能提升 2 倍。尽管开发人员的时间比模型计算时间更加稀缺,但在此例的设置中 DGL 的速度快了近 2.6 倍,凭这个优势就值得训练,并切换模型库。同时,在DGL 生态系统中遇到小问题,也会随着熟悉程度的提高而解决。虽然从 GitHub 星数和分支数就能看出来(13,700/2,400 DGL vs 8,800/2,000 PyTorch),DGL似乎不如 PyTorch Geometric那么流行,但大量社区支持和丰富的文档可以保障DGL库的易学性,同时也可以帮助解决出现的问题。无论选哪个,任何可以在网络生存数据的环境下,通过学习图中编码的结构信息都能提供诸多学习内容,而硬件和软件对稀疏矩阵快速计算的支持与改进也会让GNN 库的投资变得更有价值。原文链接:https://dzone.com/articles/pytorch-geometric-vs-deep-graph-library
推荐阅读
- DevEco Studio 3.0 Beta2 for HarmonyOS下载与安装
- Linux 应急响应
- 领域驱动设计 - 战略设计 - 1/2限界上下文
- OpenHarmony 通话应用源码剖析
- 知名网站的404页面都长啥样(最后一个绝了...)
- 进程管理二
- Java中Math.random()与Random类生成随机数及源码分析
- 40 岁失业高级码农自曝(阿里 P9,攒了 1.5亿...)
- Java设计模式—适配器模式(adapter pattern)