位置: IT常识 - 正文

Pytorch+PyG实现GCN(图卷积网络)(pytorch go)

编辑:rootadmin
Pytorch+PyG实现GCN(图卷积网络) 文章目录前言一、导入相关库二、加载Cora数据集三、定义GCN网络四、定义模型五、模型训练六、模型验证七、结果完整代码前言

推荐整理分享Pytorch+PyG实现GCN(图卷积网络)(pytorch go),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:python gcc,pytorch no_grad,pytorch no_grad,pytorch的gru,pytorch gym,pytorch gym,pytorch vgg,pytorch vgg,内容如对您有帮助,希望把文章链接给更多的朋友!

大家好,我是阿光。

本专栏整理了《图神经网络代码实战》,内包含了不同图神经网络的相关代码实现(PyG以及自实现),理论与实践相结合,如GCN、GAT、GraphSAGE等经典图网络,每一个代码实例都附带有完整的代码。

正在更新中~ ✨

🚨 我的项目环境:

平台:Windows10语言环境:python3.7编译器:PyCharmPyTorch版本:1.11.0PyG版本:2.1.0

💥 项目专栏:【图神经网络代码实战目录】

Pytorch+PyG实现GCN(图卷积网络)(pytorch go)

本文我们将使用Pytorch + Pytorch Geometric来简易实现一个GCN(图卷积网络),让新手可以理解如何PyG来搭建一个简易的图网络实例demo。

一、导入相关库

本项目我们需要结合两个库,一个是Pytorch,因为还需要按照torch的网络搭建模型进行书写,第二个是PyG,因为在torch中并没有关于图网络层的定义,所以需要torch_geometric这个库来定义一些图层。

import torchimport torch.nn.functional as Fimport torch.nn as nnimport torch_geometric.nn as pyg_nnfrom torch_geometric.datasets import Planetoid二、加载Cora数据集

本文使用的数据集是比较经典的Cora数据集,它是一个根据科学论文之间相互引用关系而构建的Graph数据集合,论文分为7类,共2708篇。

Genetic_AlgorithmsNeural_NetworksProbabilistic_MethodsReinforcement_LearningRule_LearningTheory

这个数据集是一个用于图节点分类的任务,数据集中只有一张图,这张图中含有2708个节点,10556条边,每个节点的特征维度为1433。

# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')三、定义GCN网络

这里我们就不重点介绍GCN网络了,相信大家能够掌握基本原理,本文我们使用的是PyG定义网络层,在PyG中已经定义好了GCNConv这个层,该层采用的就是GCN机制。

对于GCNConv的常用参数:

in_channels:每个样本的输入维度,就是每个节点的特征维度out_channels:经过注意力机制后映射成的新的维度,就是经过GAT后每个节点的维度长度normalize:是否添加自环,并且是否归一化,默认为Trueadd_self_loops:为图添加自环,是否考虑自身节点的信息bias:训练一个偏置b# 2.定义GCNConv网络class GCN(nn.Module): def __init__(self, num_node_features, num_classes): super(GCN, self).__init__() self.conv1 = pyg_nn.GCNConv(num_node_features, 16) self.conv2 = pyg_nn.GCNConv(16, num_classes) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)

上面网络我们定义了两个GCNConv层,第一层的参数的输入维度就是初始每个节点的特征维度,输出维度是16。

第二个层的输入维度为16,输出维度为分类个数,因为我们需要对每个节点进行分类,最终加上softmax操作。

四、定义模型

下面就是定义了一些模型需要的参数,像学习率、迭代次数这些超参数,然后是模型的定义以及优化器及损失函数的定义,和pytorch定义网络是一样的。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 10 # 学习轮数lr = 0.003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 3.定义模型model = GCN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数五、模型训练

模型训练部分也是和pytorch定义网络一样,因为都是需要经过前向传播、反向传播这些过程,对于损失、精度这些指标可以自己添加。

# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')六、模型验证

下面就是模型验证阶段,在训练时我们是只使用了训练集,测试的时候我们使用的是测试集,注意这和传统网络测试不太一样,在图像分类一些经典任务中,我们是把数据集分成了两份,分别是训练集、测试集,但是在Cora这个数据集中并没有这样,它区分训练集还是测试集使用的是掩码机制,就是定义了一个和节点长度相同纬度的数组,该数组的每个位置为True或者False,标记着是否使用该节点的数据进行训练。

# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))七、结果【EPOCH: 】1训练损失为:1.9594 训练精度为:0.1571【EPOCH: 】21训练损失为:1.8681 训练精度为:0.3286【EPOCH: 】41训练损失为:1.7647 训练精度为:0.5000【EPOCH: 】61训练损失为:1.6587 训练精度为:0.5571【EPOCH: 】81训练损失为:1.5258 训练精度为:0.6714【EPOCH: 】101训练损失为:1.4334 训练精度为:0.7143【EPOCH: 】121训练损失为:1.3361 训练精度为:0.7714【EPOCH: 】141训练损失为:1.2310 训练精度为:0.8357【EPOCH: 】161训练损失为:1.1443 训练精度为:0.8571【EPOCH: 】181训练损失为:1.0962 训练精度为:0.8714【Finished Training!】>>>Train Accuracy: 0.9357 Train Loss: 0.9735>>>Test Accuracy: 0.7200 Test Loss: 1.3561训练集测试集Accuracy0.93570.7200Loss0.97351.3561完整代码import torchimport torch.nn.functional as Fimport torch.nn as nnimport torch_geometric.nn as pyg_nnfrom torch_geometric.datasets import Planetoid# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')# 2.定义GCNConv网络class GCN(nn.Module): def __init__(self, num_node_features, num_classes): super(GCN, self).__init__() self.conv1 = pyg_nn.GCNConv(num_node_features, 16) self.conv2 = pyg_nn.GCNConv(16, num_classes) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 200 # 学习轮数lr = 0.0003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 3.定义模型model = GCN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))
本文链接地址:https://www.jiuchutong.com/zhishi/297458.html 转载请保留说明!

上一篇:chrome插件开发时跨域问题解决方案(chrome插件开发语言)

下一篇:Nginx静态资源部署(nginx搭建静态资源服务器)

  • 企业怎样玩转微博营销的技巧(怎样玩转企业微信)

    企业怎样玩转微博营销的技巧(怎样玩转企业微信)

  • realmegt大师探索版怎么开nfc

    realmegt大师探索版怎么开nfc

  • 北京健康宝如何申请通勤(北京健康宝如何解除弹窗)

    北京健康宝如何申请通勤(北京健康宝如何解除弹窗)

  • 微信号被永久封了还想用这个号码怎么办(微信号被永久封禁还能解开吗)

    微信号被永久封了还想用这个号码怎么办(微信号被永久封禁还能解开吗)

  • 电池显示维修什么情况(电池显示维修什么原因)

    电池显示维修什么情况(电池显示维修什么原因)

  • 高级语言翻译程序的实现途径有哪两种啊(高级语言翻译程序两种方法)

    高级语言翻译程序的实现途径有哪两种啊(高级语言翻译程序两种方法)

  • 卖家拉黑买家能投诉吗(卖家拉黑买家能收到钱吗)

    卖家拉黑买家能投诉吗(卖家拉黑买家能收到钱吗)

  • UDP首部的长度是多少比特(udp首部长度是固定的吗)

    UDP首部的长度是多少比特(udp首部长度是固定的吗)

  • 微信新设备登录没有好友验证(微信新设备登录限制怎么解除)

    微信新设备登录没有好友验证(微信新设备登录限制怎么解除)

  • 笔记本电脑卡了怎么关机(笔记本电脑卡了怎么结束程序)

    笔记本电脑卡了怎么关机(笔记本电脑卡了怎么结束程序)

  • wifi测速快但上网很慢(wifi测速很快,使用起来却很慢是什么原因?)

    wifi测速快但上网很慢(wifi测速很快,使用起来却很慢是什么原因?)

  • 华为nova闪光灯怎么开(华为nova闪光灯怎么)

    华为nova闪光灯怎么开(华为nova闪光灯怎么)

  • m2硬盘和ssd硬盘区别(m2硬盘和ssd硬盘可以用在一起吗)

    m2硬盘和ssd硬盘区别(m2硬盘和ssd硬盘可以用在一起吗)

  • 联通大王卡什么意思(联通大王卡什么套餐划算)

    联通大王卡什么意思(联通大王卡什么套餐划算)

  • 华为手机振动怎么调大(华为手机振动怎么关?)

    华为手机振动怎么调大(华为手机振动怎么关?)

  • 拼多多怎么弄闪电退货(拼多多闪光灯怎么关闭)

    拼多多怎么弄闪电退货(拼多多闪光灯怎么关闭)

  • 怎么看手机卡上的号码(怎么看手机卡上存的号码)

    怎么看手机卡上的号码(怎么看手机卡上存的号码)

  • 快手上买东西不给退货怎么办(快手买东西不显示订单)

    快手上买东西不给退货怎么办(快手买东西不显示订单)

  • 鸿蒙系统基于什么(鸿蒙系统基于什么架构)

    鸿蒙系统基于什么(鸿蒙系统基于什么架构)

  • 青桔单车扫码开不了锁(青桔单车怎么扫)

    青桔单车扫码开不了锁(青桔单车怎么扫)

  • 台式电脑连不上宽带(台式电脑连不上网络是什么原因)

    台式电脑连不上宽带(台式电脑连不上网络是什么原因)

  • 苹果手机怎么设置后台运行(苹果手机怎么设置动态壁纸)

    苹果手机怎么设置后台运行(苹果手机怎么设置动态壁纸)

  • 阿里tv怎么投屏(阿里tv投屏没字幕)

    阿里tv怎么投屏(阿里tv投屏没字幕)

  • ijkplayer解码流程源码解读(ijk解码是什么意思)

    ijkplayer解码流程源码解读(ijk解码是什么意思)

  • 核定征收的个体户可以开专票吗
  • 小规模纳税人增值税减免账务处理
  • 所得税残疾人工资加计扣除
  • 房开企业预售阶段预交的税费
  • 关联方交易的会计处理方法
  • 餐费补贴要交个人所得税吗
  • 银行结息计入什么费用
  • 一般纳税人工会经费返还政策
  • 印花税什么情况可以退
  • 滞纳金为千分之二从何年开始实施
  • 如何规范填写费用表格
  • 增值税税率调整时间17变16
  • 可抵扣增值税的发票
  • 股东收取了公司的货款
  • 将债务转为资本会引起负债总额发生变动吗
  • 固定资产抵扣多少年
  • 财务报表层次重大错报风险增大了认定层次
  • 企业安装监控费用怎么做账
  • 小型微利企业所得税减免政策
  • 业务招待费怎么调整应纳税所得额
  • 公司接受安全罚款的账务处理
  • 印花税可以根据企业流水申报吗
  • 应付职工薪酬明细表怎么填写
  • 鸿蒙系统获取电脑文件
  • 购买理财产品收到的利息分录
  • window10为什么没有本地用户和组
  • 汇算清缴审计报告收费标准
  • PHP:session_module_name()的用法_Session函数
  • 增值税价外费用怎么算
  • 以前年度未入账固定资产账务处理
  • vue设置宽度
  • 房地产开发企业应该具备哪些条件
  • 持有至到期投资账务处理
  • php判断文件是否存在的函数
  • 统计不同类型的数量
  • 出口抵减内销产品应纳税额怎么结转
  • 帝国cms wordpress
  • 小规模季报附加税怎么报
  • 小规模纳税人在什么情况下会成为一般纳税人
  • 加计抵减可以补提本年的税吗
  • 帝国cms灵动标签 PHP变量文章ID加减1
  • 手续费及佣金支出核算
  • 什么时候计提所得税费用会计分录
  • 中华人民共和国企业所得税年度纳税申报表
  • mysql 索性
  • mysqldump定时备份
  • 小规模费用发票可以抵扣增值税吗
  • 增值税期末留抵退税
  • 社保费阶段性减免政策到什么时候
  • 应付票据转应付账款分录
  • 所得税审核一般需要多久
  • 老板想提取销售怎么办
  • 企业补提以前年度折旧政策依据怎么写
  • 金税盘可以用热点吗
  • 普通发票和增值税发票的区别图片
  • 展位费按多少税率
  • 五金领用流程
  • 医疗器械销售能一年挣一百万么
  • 对公网银回单可以导出吗
  • MySQL删除重复数据只保留一条
  • mysql服务无效
  • gentoo安装教程2021
  • windows10预览版是什么
  • ubuntu安装后怎么启动
  • cmd命令怎么运行
  • fdb是什么文件
  • 进程 电脑
  • 小马kms激活工具
  • opengl
  • jquery的deferred
  • shader要学多久
  • node.js中的http.request.end方法使用说明
  • 基于JAVASCRIPT实现的可视化工具是
  • javascript学习指南
  • error: Error parsing XML: unbound prefix
  • swift method swizzling
  • 国家税务总局开票系统怎么开票
  • 华为领导班子成员名单
  • 进口小麦关税税率是多少
  • 房地产预缴土增值税
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设