【图神经网络|图神经网络(一)—GraphSAGE-pytorch版本代码详解】
GraphSAGE代码详解-pytorch版本
-
-
- 1. GraphSAGE基本介绍
- 2. 代码解析
-
- 2.1 加载数据
- 2.2 Unsupervised Loss
- 2.3 Models
- 2.4 评估与模型使用
- 2.5 Main
- 参考资料
-
写在前面…
在研究生的工作中使用到了图神经网络,所以平时会看一些与图神经网络相关的论文和代码。写这个系列的目的是为了帮助自己再理一遍算法的基本思想和流程,如果同时也能对其他人提供帮助是极好的~博主也是在学习过程中,有些地方有误还请大家批评指正!
- github: https://github.com/OuYangg/GNNs
- 论文标题:Inductive Representation Learning on Large Graphs
- 作者:William L. Hamilton, Rex Ying and Jure Leskovec
在GraphSAGE之前提出的图神经网络方法,如GCN,都属于transductive模型,此类模型在当网络的结构稍微出现一点改变,就需要重新训练,无法满足实时快速产生网络节点嵌入的需求。为了解决这一问题,Jure大佬等人提出了一个infuctive模型,那就是GraphSAGE。GraphSAGE的目标是训练多个aggregator以聚合目标节点不同阶的邻居节点信息,从而可以快速生成未知节点的低维向量表示。
- GraphSAGE的基本流程见下图:
文章图片
1)首先通过随机游走获得固定大小的邻域网络 2)然后通过aggregator把有限阶邻居节点的特征聚合给目标节点,伪代码如下
文章图片
由上面的伪代码可见,GraphSAGE的输入为:目标网络 G G G、节点的特征向量 x v x_v xv?、权重矩阵 W k W^k Wk、非线性激活函数 σ \sigma σ、aggregator函数以及邻居函数 N N N. 1)首先 h 0 h_0 h0?为节点的特征向量,循环 K K K步 2)遍历每个节点,对于每个节点首先聚合邻居节点 k ? 1 k-1 k?1时刻的特征,然后将聚合的结果与当前节点 k ? 1 k-1 k?1时刻的特征进行concat并经过一个激活函数 3)当循环完K步之后进行一个令 h k h_k hk?除以 ∣ ∣ h k ∣ ∣ 2 ||h_k||_2 ∣∣hk?∣∣2?得到节点的低维表示。
- Aggregators: 从上面的流程图和伪代码可见,GraphSAGE需要用到aggregators,那这个aggregator是什么呢?起到什么作用呢?其实aggregator的作用就是将目标节点的邻居信息进行一个聚合,作者在文中给出了3种不同的aggregators分别是:
1)Mean aggregator: 该策略是将邻居节点与目标节点特征向量的值取平均
文章图片
2)LSTM aggregator: 利用LSTM来聚合邻居节点的信息。
3)Pooling aggregator: 在使用pooling聚合器的时候,每个邻居节点的特征逐一的经过一个全连接层,从而进行池化操作
文章图片
这里的max是一个element-wise的max.
- 代码参考地址:graphSAGE-pytorch
- 导入所需的库
import numpy as np
import pandas as pd
import os,sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import math
from sklearn.utils import shuffle
from sklearn.metrics import f1_score
from collections import defaultdict
2.1 加载数据
- 本代码使用的数据cora共包含两个文件,分别是cora_content和cora_cite. 数据详细描述可参考链接:Cora数据集描述
class DataCenter(object):
"""加载数据集
Parameter:
file_paths:{数据文件存放地址1,数据文件存放地址2}
"""
def __init__(self,file_paths):
"""file_paths:{name:root,...,}"""
super(DataCenter,self).__init__()
self.file_paths = file_pathsdef load_Dataset(self,dataset='cora'):
"""读取存放在指定路径的数据集"""
feat_list = [] # 用于存放每个节点特征向量的列表
label_list = [] # 用于存放每个节点对应类别的列表
node_map = {} # 将节点进行重新编码
label_map = {} # 将label映射为数字if dataset == 'cora':
content = self.file_paths['cora_content'] # 获取cora_content的地址
cite = self.file_paths['cora_cite'] # 获取cora_cite的地址
with open(content) as f1:
for i,each_sample in enumerate(f1.readlines()): # 遍历每个样本的特征
sample_clean = each_sample.strip().split()
# 提取每个样本的特征,其中第一个元素和最后一个元素是样本名称和对应的标签
feat_list.append(sample_clean[1:-1])
# 把节点名称映射为节点编号
node_map[sample_clean[0]]=i
label = sample_clean[-1]
if label not in label_map.keys():
# 把label转化为数字
label_map[label] = len(label_map)
label_list.append(label_map[label])
feat_list = np.asarray(feat_list,dtype=np.float64)
label_list = np.asarray(label_list,dtype=np.int64)# 获得每个节点的邻居{v0:[v0的邻居集合],v1:[v1的邻居集合]}
adj_lists = defaultdict(set)
with open(cite) as f2:
for j,each_pair in enumerate(f2.readlines()):
pair = each_pair.strip().split()
assert len(pair) == 2
adj_lists[node_map[pair[0]]].add(node_map[pair[1]])
adj_lists[node_map[pair[1]]].add(node_map[pair[0]])assert len(feat_list) == len(label_list) == len(adj_lists)
train_index,test_index,val_index = self._split_data(feat_list.shape[0])
# 使用getattr()可以获得数据
setattr(self,dataset+'_test',test_index)
setattr(self,dataset+'_val',val_index)
setattr(self,dataset+'_train',train_index)
setattr(self,dataset+'_feats',feat_list)
setattr(self,dataset+'_labels',label_list)
setattr(self,dataset+'_adj_lists',adj_lists)def _split_data(self,number_of_nodes,test_split=3,val_split=6):
"""获得训练集、验证集和测试集"""
# 打乱顺序
rand_indices = np.random.permutation(number_of_nodes)
test_size = number_of_nodes // test_split
val_size = number_of_nodes // val_split
test_index = rand_indices[:test_size]
val_index = rand_indices[test_size:test_size+val_size]
train_index = rand_indices[test_size+val_size:]
return train_index,test_index,val_index
2.2 Unsupervised Loss GraphSAGE定义的Loss函数如下所示, J G ( z u ) = ? l o g ( σ ( z u T z v ) ) ? Q E v nP n ( v ) l o g ( σ ( ? z u T z v n ) ) J_G(z_u)=-log(\sigma(z_u^Tz_v))-QE_{v_n ~ P_n(v)}log(\sigma(-z_u^Tz_{v_n})) JG?(zu?)=?log(σ(zuT?zv?))?QEvn? Pn?(v)?log(σ(?zuT?zvn??))
其中, Q Q Q为负样本数量,前一项是根据正样本计算的Loss,后一项是根据负样本计算的Loss.
class UnsupervisedLoss(object):
"""docstring for UnsupervisedLoss"""
def __init__(self, adj_lists, train_nodes, device):
"""初始化参数"""
super(UnsupervisedLoss, self).__init__()
self.Q = 10 # 负样本的数量
self.N_WALKS = 6 # 每个节点随机游走的次数
self.WALK_LEN = 1 # 每次随机游走的步长
self.N_WALK_LEN = 5 # 每次负样本随机游走几个节点
self.MARGIN = 3
self.adj_lists = adj_lists #{v0:[v0的邻居集合],v1:[v1的邻居集合],...,vn:[vn的邻居集合]}
self.train_nodes = train_nodes # 训练节点
self.device = device # cpu or gpuself.target_nodes = None
self.positive_pairs = [] # 存放正例样本 [(v0,v0邻居中采样到的正例节点),....,]
self.negtive_pairs = [] # 存放负例样本 [(v0,v0邻居中采样到的负例节点),....,]
self.node_positive_pairs = {} # {v0:[(v0,从v0开始随机游走采样到的正例节点),(v0,从v0开始随机游走采样到的正例节点)],...,vn:[(vn,从vn开始随机游走采样到的正例节点)]}
self.node_negtive_pairs = {} # {v0:[(v0,从v0开始随机游走采样到的负例节点),(v0,从v0开始随机游走采样到的负例节点)],...,vn:[(vn,从vn开始随机游走采样到的负例节点)]}
self.unique_nodes_batch = [] # 一个batch所有会用到的节点及其邻居节点 def get_loss_sage(self, embeddings, nodes):
"""根据论文里的公式计算损失函数"""
assert len(embeddings) == len(self.unique_nodes_batch) #判断是不是每个节点都有了embeddings
assert False not in [nodes[i]==self.unique_nodes_batch[i] for i in range(len(nodes))] # 判断目标节点集和unique集里的节点是否1一一对应
node2index = {n:i for i,n in enumerate(self.unique_nodes_batch)} # 把节点重新编码nodes_score = []
assert len(self.node_positive_pairs) == len(self.node_negtive_pairs) # 确定正例节点对和负例节点对的数量是否相同
for node in self.node_positive_pairs: # 遍历所有节点
pps = self.node_positive_pairs[node] # 获得对应的正例 [(v0,v0正例样本1),(v0,v0正例样本2),...,(v0,v0正例样本n)]
nps = self.node_negtive_pairs[node] # 获得每个节点对应的负例 [(v0,v0负例样本1),(v0,v0负例样本2),...,(v0,v0负例样本n)]
if len(pps) == 0 or len(nps) == 0: # 判断是否都有正例和负例
continue# Q * Exception(negative score)计算负例样本的Loss,即Loss函数的后一项
indexs = [list(x) for x in zip(*nps)] # [[源节点,...,源节点],[采样得到的负节点1,...,采样得到的负节点n]]
node_indexs = [node2index[x] for x in indexs[0]] # 获得源节点的编号
neighb_indexs = [node2index[x] for x in indexs[1]] # 负样本节点的编号
neg_score = F.cosine_similarity(embeddings[node_indexs], embeddings[neighb_indexs]) # 计算余弦相似性
neg_score = self.Q*torch.mean(torch.log(torch.sigmoid(-neg_score)), 0) # 计算损失的后一项
#print(neg_score)# multiple positive score 计算正列样本的Loss,即Loss函数的前一项
indexs = [list(x) for x in zip(*pps)]
node_indexs = [node2index[x] for x in indexs[0]]
neighb_indexs = [node2index[x] for x in indexs[1]]
pos_score = F.cosine_similarity(embeddings[node_indexs], embeddings[neighb_indexs])
pos_score = torch.log(torch.sigmoid(pos_score)) # 计算损失的前一项
#print(pos_score)nodes_score.append(torch.mean(- pos_score - neg_score).view(1,-1)) # 把每个节点的损失加入到列表中loss = torch.mean(torch.cat(nodes_score, 0)) # 求平均return loss def get_loss_margin(self, embeddings, nodes):
assert len(embeddings) == len(self.unique_nodes_batch)
assert False not in [nodes[i]==self.unique_nodes_batch[i] for i in range(len(nodes))]
node2index = {n:i for i,n in enumerate(self.unique_nodes_batch)}nodes_score = []
assert len(self.node_positive_pairs) == len(self.node_negtive_pairs)
for node in self.node_positive_pairs:
pps = self.node_positive_pairs[node]
nps = self.node_negtive_pairs[node]
if len(pps) == 0 or len(nps) == 0:
continueindexs = [list(x) for x in zip(*pps)]
node_indexs = [node2index[x] for x in indexs[0]]
neighb_indexs = [node2index[x] for x in indexs[1]]
pos_score = F.cosine_similarity(embeddings[node_indexs], embeddings[neighb_indexs])
pos_score, _ = torch.min(torch.log(torch.sigmoid(pos_score)), 0)indexs = [list(x) for x in zip(*nps)]
node_indexs = [node2index[x] for x in indexs[0]]
neighb_indexs = [node2index[x] for x in indexs[1]]
neg_score = F.cosine_similarity(embeddings[node_indexs], embeddings[neighb_indexs])
neg_score, _ = torch.max(torch.log(torch.sigmoid(neg_score)), 0)nodes_score.append(torch.max(torch.tensor(0.0).to(self.device), neg_score-pos_score+self.MARGIN).view(1,-1))
# nodes_score.append((-pos_score - neg_score).view(1,-1))loss = torch.mean(torch.cat(nodes_score, 0),0)# loss = -torch.log(torch.sigmoid(pos_score))-4*torch.log(torch.sigmoid(-neg_score))return loss def extend_nodes(self, nodes, num_neg=6):
"""获得目标节点集的正样本和负样本,输出这些节点的集合"""
self.positive_pairs = []
self.node_positive_pairs = {}
self.negtive_pairs = []
self.node_negtive_pairs = {}self.target_nodes = nodes
self.get_positive_nodes(nodes)
# print(self.positive_pairs)
self.get_negtive_nodes(nodes, num_neg)
# print(self.negtive_pairs)
self.unique_nodes_batch = list(set([i for x in self.positive_pairs for i in x]) | set([i for x in self.negtive_pairs for i in x]))
assert set(self.target_nodes) < set(self.unique_nodes_batch)
return self.unique_nodes_batch def get_positive_nodes(self, nodes):
return self._run_random_walks(nodes) # 通过随机游走获得正列样本 def get_negtive_nodes(self, nodes, num_neg):
"""
生成负样本,也就是让目标节点与目标节点相隔很远的节点组成一个负例
"""
for node in nodes: # 遍历每个节点
neighbors = set([node])
frontier = set([node])
for i in range(self.N_WALK_LEN):
current = set()
for outer in frontier:
current |= self.adj_lists[int(outer)] #获取frontier中所有的邻居节点
frontier = current - neighbors #去除源节点
neighbors |= current # 源节点+邻居节点
far_nodes = set(self.train_nodes) - neighbors # 减去train_nodes里源节点及其一阶邻居
neg_samples = random.sample(far_nodes, num_neg) if num_neg < len(far_nodes) else far_nodes # 从二阶邻居开始采样
self.negtive_pairs.extend([(node, neg_node) for neg_node in neg_samples])
self.node_negtive_pairs[node] = [(node, neg_node) for neg_node in neg_samples]
return self.negtive_pairs def _run_random_walks(self, nodes):
for node in nodes: # 遍历每个节点
if len(self.adj_lists[int(node)]) == 0: # 若该节点没有邻居节点则跳过
continue
cur_pairs = [] # 创建一个
for i in range(self.N_WALKS): # 每个节点会有N_WALKS次的随机游走
curr_node = node #
for j in range(self.WALK_LEN): # 每次随机游走走WALK_LEN的长度
neighs = self.adj_lists[int(curr_node)]
next_node = random.choice(list(neighs))
# self co-occurrences are useless
if next_node != node and next_node in self.train_nodes:
self.positive_pairs.append((node,next_node))
cur_pairs.append((node,next_node))
curr_node = next_nodeself.node_positive_pairs[node] = cur_pairs
return self.positive_pairs
2.3 Models
- Classification model
class Classification(nn.Module):
"""一个最简单的一层分类模型
Parameters:
input_size:输入维度
num_classes:类别数量
return:
logists:最大概率对应的标签
"""
def __init__(self,input_size,num_classes):
super(Classification,self).__init__()
self.fc1 = nn.Linear(input_size,num_classes) # 定义一个input_size*num_classes的线性层
self.init_params() # 初始化权重参数def init_params(self):
for param in self.parameters():
if len(param.size()) == 2: # 如果参数是矩阵的话就重新初始化
nn.init.xavier_uniform_(param) def forward(self,x):
logists = torch.log_softmax(self.fc1(x),1) # 利用log_softmax来获得最终输出的类别
return logists
- GraphSAGE
class SageLayer(nn.Module):
"""
一层SageLayer
"""
def __init__(self, input_size, out_size, gcn=False):
super(SageLayer, self).__init__()
self.input_size = input_size
self.out_size = out_size
self.gcn = gcn
self.weight = nn.Parameter(torch.FloatTensor(out_size, self.input_size if self.gcn else 2 * self.input_size)) #初始化权重参数w*input.T
self.init_params() # 调整权重参数分布 def init_params(self):
for param in self.parameters():
nn.init.xavier_uniform_(param) def forward(self, self_feats, aggregate_feats, neighs=None):
"""
Parameters:
self_feats:源节点的特征向量
aggregate_feats:聚合后的邻居节点特征
"""
if not self.gcn: # 如果不是gcn的话就要进行concatenate
combined = torch.cat([self_feats, aggregate_feats], dim=1)
else:
combined = aggregate_feats
combined = F.relu(self.weight.mm(combined.t())).t()
return combinedclass GraphSage(nn.Module):
"""定义一个GraphSage模型"""
def __init__(self, num_layers, input_size, out_size, raw_features, adj_lists, device, gcn=False, agg_func='MEAN'):
super(GraphSage, self).__init__()
self.input_size = input_size
self.out_size = out_size
self.num_layers = num_layers # Graphsage的层数
self.gcn = gcn
self.device = device
self.agg_func = agg_func
self.raw_features = raw_features
self.adj_lists = adj_lists
# 定义每一层的输入和输出
for index in range(1, num_layers+1):
layer_size = out_size if index != 1 else input_size
setattr(self, 'sage_layer'+str(index), SageLayer(layer_size, out_size, gcn=self.gcn))#除了第1层的输入为input_size,其余层的输入和输出均为outsize def forward(self, nodes_batch):
"""
为一批节点生成嵌入表示
Parameters:
nodes_batch:目标批次的节点
"""
lower_layer_nodes = list(nodes_batch) # 初始化第一层节点
nodes_batch_layers = [(lower_layer_nodes,)] # 存放每一层的节点信息
for i in range(self.num_layers):
lower_samp_neighs, lower_layer_nodes_dict, lower_layer_nodes= self._get_unique_neighs_list(lower_layer_nodes) # 根据当前层节点获得下一层节点
nodes_batch_layers.insert(0, (lower_layer_nodes, lower_samp_neighs, lower_layer_nodes_dict))assert len(nodes_batch_layers) == self.num_layers + 1pre_hidden_embs = self.raw_features # 初始化h0
for index in range(1, self.num_layers+1):
nb = nodes_batch_layers[index][0]#所有邻居节点
pre_neighs = nodes_batch_layers[index-1] # 上一层的邻居节点
aggregate_feats = self.aggregate(nb, pre_hidden_embs, pre_neighs)
sage_layer = getattr(self, 'sage_layer'+str(index))
if index > 1:
nb = self._nodes_map(nb, pre_hidden_embs, pre_neighs)
# self.dc.logger.info('sage_layer.')
cur_hidden_embs = sage_layer(self_feats=pre_hidden_embs[nb],
aggregate_feats=aggregate_feats)
pre_hidden_embs = cur_hidden_embsreturn pre_hidden_embs def _nodes_map(self, nodes, hidden_embs, neighs):
layer_nodes, samp_neighs, layer_nodes_dict = neighs
assert len(samp_neighs) == len(nodes)
index = [layer_nodes_dict[x] for x in nodes]
return index def _get_unique_neighs_list(self, nodes, num_sample=10):
_set = set
to_neighs = [self.adj_lists[int(node)] for node in nodes] # 获取目标节点集的所有邻居节点[[v0的邻居],[v1的邻居],[v2的邻居]]
if not num_sample is None: # 如果num_sample为实数的话
_sample = random.sample
samp_neighs = [_set(_sample(to_neigh, num_sample)) if len(to_neigh) >= num_sample else to_neigh for to_neigh in to_neighs] # [set(随机采样的邻居集合),set(),set()]
# 遍历所有邻居集合如果邻居节点数>=num_sample,就从邻居节点集中随机采样num_sample个邻居节点,否则直接把邻居节点集放进去
else:
samp_neighs = to_neighs
samp_neighs = [samp_neigh | set([nodes[i]]) for i, samp_neigh in enumerate(samp_neighs)] # 把源节点也放进去
_unique_nodes_list = list(set.union(*samp_neighs)) #展平
i = list(range(len(_unique_nodes_list))) # 重新编号
unique_nodes = dict(list(zip(_unique_nodes_list, i)))
return samp_neighs, unique_nodes, _unique_nodes_list def aggregate(self, nodes, pre_hidden_embs, pre_neighs, num_sample=10):
"""聚合邻居节点信息
Parameters:
nodes:从最外层开始的节点集合
pre_hidden_embs:上一层的节点嵌入
pre_neighs:上一层的节点
"""
unique_nodes_list, samp_neighs, unique_nodes = pre_neighs # 上一层的源节点,...,....,
assert len(nodes) == len(samp_neighs)
indicator = [(nodes[i] in samp_neighs[i]) for i in range(len(samp_neighs))] # 判断每个节点是否出现在邻居节点中
assert (False not in indicator)
if not self.gcn:
# 如果不适用gcn就要把源节点去除
samp_neighs = [(samp_neighs[i]-set([nodes[i]])) for i in range(len(samp_neighs))]
if len(pre_hidden_embs) == len(unique_nodes):
embed_matrix = pre_hidden_embs
else:
embed_matrix = pre_hidden_embs[torch.LongTensor(unique_nodes_list)]
# self.dc.logger.info('3')
mask = torch.zeros(len(samp_neighs), len(unique_nodes))
column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))]
mask[row_indices, column_indices] = 1 # 每个源节点为一行,一行元素中1对应的就是邻居节点的位置if self.agg_func == 'MEAN':
num_neigh = mask.sum(1, keepdim=True) # 计算每个源节点有多少个邻居节点
mask = mask.div(num_neigh).to(embed_matrix.device) #
aggregate_feats = mask.mm(embed_matrix)elif self.agg_func == 'MAX':
# print(mask)
indexs = [x.nonzero() for x in mask==1]
aggregate_feats = []
for feat in [embed_matrix[x.squeeze()] for x in indexs]:
if len(feat.size()) == 1:
aggregate_feats.append(feat.view(1, -1))
else:
aggregate_feats.append(torch.max(feat,0)[0].view(1, -1))
aggregate_feats = torch.cat(aggregate_feats, 0)
return aggregate_feats
2.4 评估与模型使用
def evaluate(dataCenter, ds, graphSage, classification, device, max_vali_f1, name, cur_epoch):
"""
测试模型的性能
Parameters:
datacenter:创建好的datacenter对像
ds:数据集的名称
graphSage:训练好的graphSage对像
classification:训练好的classificator"""
test_nodes = getattr(dataCenter, ds+'_test') # 获得测试集
val_nodes = getattr(dataCenter, ds+'_val') # 获得验证集
labels = getattr(dataCenter, ds+'_labels') # 获得标签 models = [graphSage, classification] params = [] # 将两个模型的参数存入一个列表中
for model in models:
for param in model.parameters():
if param.requires_grad:
param.requires_grad = False
params.append(param) embs = graphSage(val_nodes)
logists = classification(embs)
_, predicts = torch.max(logists, 1)
labels_val = labels[val_nodes]
assert len(labels_val) == len(predicts)
comps = zip(labels_val, predicts.data) vali_f1 = f1_score(labels_val, predicts.cpu().data, average="micro")
print("Validation F1:", vali_f1) if vali_f1 > max_vali_f1:
max_vali_f1 = vali_f1
embs = graphSage(test_nodes)
logists = classification(embs)
_, predicts = torch.max(logists, 1)
labels_test = labels[test_nodes]
assert len(labels_test) == len(predicts)
comps = zip(labels_test, predicts.data)test_f1 = f1_score(labels_test, predicts.cpu().data, average="micro")
print("Test F1:", test_f1)for param in params:
param.requires_grad = Truetorch.save(models, './model_best_{}_ep{}_{:.4f}.torch'.format(name, cur_epoch, test_f1)) for param in params:
param.requires_grad = True return max_vali_f1def get_gnn_embeddings(gnn_model, dataCenter, ds):
"""使用GraphSage获得节点的嵌入表示"""
print('Loading embeddings from trained GraphSAGE model.')
features = np.zeros((len(getattr(dataCenter, ds+'_labels')), gnn_model.out_size))
nodes = np.arange(len(getattr(dataCenter, ds+'_labels'))).tolist()
b_sz = 500
batches = math.ceil(len(nodes) / b_sz)
embs = []
for index in range(batches):
nodes_batch = nodes[index*b_sz:(index+1)*b_sz]
embs_batch = gnn_model(nodes_batch)
assert len(embs_batch) == len(nodes_batch)
embs.append(embs_batch)
# if ((index+1)*b_sz) % 10000 == 0:
#print(f'Dealed Nodes [{(index+1)*b_sz}/{len(nodes)}]')assert len(embs) == batches
embs = torch.cat(embs, 0)
assert len(embs) == len(nodes)
print('Embeddings loaded.')
return embs.detach()def train_classification(dataCenter, graphSage, classification, ds, device, max_vali_f1, name, epochs=800):
"""训练分类器"""
print('Training Classification ...')
c_optimizer = torch.optim.SGD(classification.parameters(), lr=0.5)
# train classification, detached from the current graph
#classification.init_params()
b_sz = 50
train_nodes = getattr(dataCenter, ds+'_train')
labels = getattr(dataCenter, ds+'_labels')
features = get_gnn_embeddings(graphSage, dataCenter, ds)
for epoch in range(epochs):
train_nodes = shuffle(train_nodes)
batches = math.ceil(len(train_nodes) / b_sz)
visited_nodes = set()
for index in range(batches):
nodes_batch = train_nodes[index*b_sz:(index+1)*b_sz]
visited_nodes |= set(nodes_batch)
labels_batch = labels[nodes_batch]
embs_batch = features[nodes_batch]logists = classification(embs_batch)
loss = -torch.sum(logists[range(logists.size(0)), labels_batch], 0)
loss /= len(nodes_batch)
# print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Dealed Nodes [{}/{}] '.format(epoch+1, epochs, index, batches, loss.item(), len(visited_nodes), len(train_nodes)))loss.backward()nn.utils.clip_grad_norm_(classification.parameters(), 5)
c_optimizer.step()
c_optimizer.zero_grad()max_vali_f1 = evaluate(dataCenter, ds, graphSage, classification, device, max_vali_f1, name, epoch)
return classification, max_vali_f1def apply_model(dataCenter, ds, graphSage, classification, unsupervised_loss, b_sz, unsup_loss, device, learn_method):
test_nodes = getattr(dataCenter, ds+'_test')
val_nodes = getattr(dataCenter, ds+'_val')
train_nodes = getattr(dataCenter, ds+'_train')
labels = getattr(dataCenter, ds+'_labels') if unsup_loss == 'margin':
num_neg = 6
elif unsup_loss == 'normal':
num_neg = 100
else:
print("unsup_loss can be only 'margin' or 'normal'.")
sys.exit(1) train_nodes = shuffle(train_nodes) models = [graphSage, classification]
params = []
for model in models:
for param in model.parameters():
if param.requires_grad:
params.append(param) optimizer = torch.optim.SGD(params, lr=0.7)
optimizer.zero_grad()
for model in models:
model.zero_grad() batches = math.ceil(len(train_nodes) / b_sz) visited_nodes = set()
for index in range(batches):
nodes_batch = train_nodes[index*b_sz:(index+1)*b_sz]# extend nodes batch for unspervised learning
# no conflicts with supervised learning
nodes_batch = np.asarray(list(unsupervised_loss.extend_nodes(nodes_batch, num_neg=num_neg)))
visited_nodes |= set(nodes_batch)# get ground-truth for the nodes batch
labels_batch = labels[nodes_batch]# feed nodes batch to the graphSAGE
# returning the nodes embeddings
embs_batch = graphSage(nodes_batch)if learn_method == 'sup':
# superivsed learning
logists = classification(embs_batch)
loss_sup = -torch.sum(logists[range(logists.size(0)), labels_batch], 0)
loss_sup /= len(nodes_batch)
loss = loss_sup
elif learn_method == 'plus_unsup':
# superivsed learning
logists = classification(embs_batch)
loss_sup = -torch.sum(logists[range(logists.size(0)), labels_batch], 0)
loss_sup /= len(nodes_batch)
# unsuperivsed learning
if unsup_loss == 'margin':
loss_net = unsupervised_loss.get_loss_margin(embs_batch, nodes_batch)
elif unsup_loss == 'normal':
loss_net = unsupervised_loss.get_loss_sage(embs_batch, nodes_batch)
loss = loss_sup + loss_net
else:
if unsup_loss == 'margin':
loss_net = unsupervised_loss.get_loss_margin(embs_batch, nodes_batch)
elif unsup_loss == 'normal':
loss_net = unsupervised_loss.get_loss_sage(embs_batch, nodes_batch)
loss = loss_netprint('Step [{}/{}], Loss: {:.4f}, Dealed Nodes [{}/{}] '.format(index+1, batches, loss.item(), len(visited_nodes), len(train_nodes)))
loss.backward()
for model in models:
nn.utils.clip_grad_norm_(model.parameters(), 5)
optimizer.step()optimizer.zero_grad()
for model in models:
model.zero_grad() return graphSage, classification
2.5 Main
file_paths = {'cora_content':'./cora.content','cora_cite':'./cora.cites'}
datacenter= DataCenter(file_paths)
datacenter.load_Dataset()
feature_data = https://www.it610.com/article/torch.FloatTensor(getattr(datacenter,'cora'+'_feats'))
label_data = https://www.it610.com/article/torch.from_numpy(getattr(datacenter,'cora'+'_labels')).long()
adj_lists = getattr(datacenter,'cora'+'_adj_lists')
random.seed(824)
np.random.seed(824)
torch.manual_seed(824)
torch.cuda.manual_seed_all(824)
learn_method = 'sup'
ds = 'cora'
epochs = 50
max_vali_f1=0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
graphSage = GraphSage(2, feature_data.size(1), 128, feature_data, getattr(datacenter, ds+'_adj_lists'), device, gcn='store_true', agg_func='MEAN')
num_labels = len(set(getattr(datacenter, ds+'_labels')))
classification = Classification(128, num_labels)
unsupervised_loss = UnsupervisedLoss(getattr(datacenter, ds+'_adj_lists'), getattr(datacenter, ds+'_train'), device)
if learn_method == 'sup':
print('GraphSage with Supervised Learning')
elif learn_method == 'plus_unsup':
print('GraphSage with Supervised Learning plus Net Unsupervised Learning')
else:
print('GraphSage with Net Unsupervised Learning')for epoch in range(epochs):
print('----------------------EPOCH %d-----------------------' % epoch)
graphSage, classification = apply_model(datacenter, ds, graphSage, classification, unsupervised_loss, 20, 'normal', device, learn_method)
if (epoch+1) % 2 == 0 and learn_method == 'unsup':
classification, max_vali_f1 = train_classification(datacenter, graphSage, classification, ds, device,max_vali_f1, 'debug')
if learn_method != 'unsup':
max_vali_f1 = evaluate(datacenter, ds, graphSage, classification, device, max_vali_f1 , 'debug', epoch)
- 输出结果如下:
文章图片
[1] Hamilton W L, Ying R, Leskovec J. Inductive representation learning on large graphs[J]. arXiv preprint arXiv:1706.02216, 2017.
[2] https://github.com/twjiang/graphSAGE-pytorch
推荐阅读
- 深度学习|Yolov5训练指南—CoCo格式数据集
- 如何在计算机中训练神经网络或随机森林算法并在以后的Android设备中进行测试()
- 实践torch.fx第一篇——基于Pytorch的模型优化量化神器
- 深度学习|【庖丁解牛】从零实现FCOS(终)(CenterSample的重要性)
- 炼丹|使用FCOS训练自己的数据
- AI|论文研读(三)(FCOS: Fully Convolutional One-Stage Object Detection之补充)
- 用 Python 做“贪吃蛇”,在线吃不饱
- 程序员|手把手教你使用 Python 制作贪吃蛇游戏
- c++|Python 什么时候会被取代()