找回密码
 会员注册
查看: 27|回复: 0

GCN-图卷积神经网络算法简单实现(含python代码)

[复制链接]

1

主题

0

回帖

4

积分

新手上路

积分
4
发表于 2024-9-3 15:03:04 | 显示全部楼层 |阅读模式
本文是就实现GCN算法模型进行的代码介绍,上一篇文章是GCN算法的原理和模型介绍。代码中用到的Cora数据集:链接:https://pan.baidu.com/s/1SbqIOtysKqHKZ7C50DM_eA 提取码:pfny 文章目录目的一、数据集介绍二、实现过程讲解三、代码实现和结果分析1.导入包2.数据准备¶3. 图卷积层定义4.GCN图卷积神经网络模型定义5. 模型训练5.1超参数定义,包含学习率、正则化系数等。5.2定义模型:5.3定义训练和测试函数,进行训练6.可视化目的本次实验的目的是将论文分类,通过模型训练,利用已经分好类的训练集,将论文通过GCN算法分为7类。一、数据集介绍数据集我选用的是GCN常用的Cora数据集,实验的目标就是通过对构造出来的两层GCN模型进行训练,实现对数据集样本节点的分类Cora数据集下载地址:https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz个人不建议用python的dgl包中的Cora数据,总是报错。Cora数据集由关于机器学习方面的论文组成。这些论文分为以下七个类别之一:1.基于案例2.遗传算法3.神经网络4.概率方法5.强化学习6.规则学习7.理论这些论文都是经过筛选的,在最终的数据集中,每篇论文引用或被至少一篇其他论文引用。整个语料库中有2708篇论文。在词干堵塞和去除词尾后,只剩下1433个唯一的单词。文档频率小于10的所有单词都被删除。即Cora数据集包含2708个顶点,5429条边,每个顶点包含1433个特征,共有7个类别。并且Cora已经把训练集和测试集的数据都划分好了,直接按照文件名读取数据即可,如文件ind.cora.x=>训练实例的特征向量;ind.cora.y=>训练实例的标签,独热编码ind.cora.tx=>测试实例的特征向量;ind.cora.ty=>测试实例的标签,独热编码二、实现过程讲解结合我最后做的代码实现,给大家先举一个引文网络的简单实例,方便大家了解处理过程。其中每个节点代表一篇研究论文,同时边代表的是引用关系。我们在这里有一个预处理步骤。在这里我们不使用原始论文作为特征,而是将论文转换成向量(通过使用NLP嵌入,例如tf-idf)。假设我们使用average()函数(实际上GCN内部的传递函数肯定不是平均值,这里只是方便理解)。我们将对所有的节点进行同样的获取特征向量的操作。最后,我们将这些计算得到的平均值输入到神经网络中。让我们考虑下绿色节点。首先,我们得到它的所有邻居的特征值,包括自身节点,接着取平均值。最后通过神经网络返回一个结果向量并将此作为最终结果。请注意,在GCN中,我们仅仅使用一个全连接层。在这个例子中,我们得到2维向量作为输出(全连接层的2个节点)。全连接网络的作用就是对上一层得到的向量做乘法,最终降低其维度,然后输入到softmax层中得到对应的每个类别的得分。在实际操作中,我们肯定是使用比average函数更复杂的聚合函数,也就是上面讲的那个传播函数。我们还可以将更多的层叠加在一起,以获得更深的GCN。其中每一层的输出会被视为下一层的输入。2层GCN的例子:第一层的输出是第二层的输入。那么两层的GCN就可以在降维的同时,通过层间传播的公式获取到二阶邻居节点的特征: 在节点分类问题中,实际上在输入的邻接矩阵和每个节点的特征中,既包含了节点间的联系情况,也包含了节点自身的特征。通过GCN的卷积层就可以实现降维,想要聚成几类就降成几维。三、代码实现和结果分析1.导入包importitertoolsimportosimportos.pathasospimportpickleimporturllibfromcollectionsimportnamedtupleimportwarningswarnings.filterwarnings("ignore")importnumpyasnpimportscipy.sparseasspimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFimporttorch.nn.initasinitimporttorch.optimasoptimimportmatplotlib.pyplotasplt%matplotlibinline2.数据准备¶Data=namedtuple('Data',['x','y','adjacency','train_mask','val_mask','test_mask'])deftensor_from_numpy(x,device):returntorch.from_numpy(x).to(device)classCoraDa
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 会员注册

本版积分规则

QQ|手机版|心飞设计-版权所有:微度网络信息技术服务中心 ( 鲁ICP备17032091号-12 )|网站地图

GMT+8, 2024-12-27 13:59 , Processed in 0.487973 second(s), 26 queries .

Powered by Discuz! X3.5

© 2001-2024 Discuz! Team.

快速回复 返回顶部 返回列表