位置: IT常识 - 正文

Pytorch+PyG实现GCN(图卷积网络)(pytorch go)

编辑:rootadmin
Pytorch+PyG实现GCN(图卷积网络) 文章目录前言一、导入相关库二、加载Cora数据集三、定义GCN网络四、定义模型五、模型训练六、模型验证七、结果完整代码前言

推荐整理分享Pytorch+PyG实现GCN(图卷积网络)(pytorch go),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:python gcc,pytorch no_grad,pytorch no_grad,pytorch的gru,pytorch gym,pytorch gym,pytorch vgg,pytorch vgg,内容如对您有帮助,希望把文章链接给更多的朋友!

大家好,我是阿光。

本专栏整理了《图神经网络代码实战》,内包含了不同图神经网络的相关代码实现(PyG以及自实现),理论与实践相结合,如GCN、GAT、GraphSAGE等经典图网络,每一个代码实例都附带有完整的代码。

正在更新中~ ✨

🚨 我的项目环境:

平台:Windows10语言环境:python3.7编译器:PyCharmPyTorch版本:1.11.0PyG版本:2.1.0

💥 项目专栏:【图神经网络代码实战目录】

Pytorch+PyG实现GCN(图卷积网络)(pytorch go)

本文我们将使用Pytorch + Pytorch Geometric来简易实现一个GCN(图卷积网络),让新手可以理解如何PyG来搭建一个简易的图网络实例demo。

一、导入相关库

本项目我们需要结合两个库,一个是Pytorch,因为还需要按照torch的网络搭建模型进行书写,第二个是PyG,因为在torch中并没有关于图网络层的定义,所以需要torch_geometric这个库来定义一些图层。

import torchimport torch.nn.functional as Fimport torch.nn as nnimport torch_geometric.nn as pyg_nnfrom torch_geometric.datasets import Planetoid二、加载Cora数据集

本文使用的数据集是比较经典的Cora数据集,它是一个根据科学论文之间相互引用关系而构建的Graph数据集合,论文分为7类,共2708篇。

Genetic_AlgorithmsNeural_NetworksProbabilistic_MethodsReinforcement_LearningRule_LearningTheory

这个数据集是一个用于图节点分类的任务,数据集中只有一张图,这张图中含有2708个节点,10556条边,每个节点的特征维度为1433。

# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')三、定义GCN网络

这里我们就不重点介绍GCN网络了,相信大家能够掌握基本原理,本文我们使用的是PyG定义网络层,在PyG中已经定义好了GCNConv这个层,该层采用的就是GCN机制。

对于GCNConv的常用参数:

in_channels:每个样本的输入维度,就是每个节点的特征维度out_channels:经过注意力机制后映射成的新的维度,就是经过GAT后每个节点的维度长度normalize:是否添加自环,并且是否归一化,默认为Trueadd_self_loops:为图添加自环,是否考虑自身节点的信息bias:训练一个偏置b# 2.定义GCNConv网络class GCN(nn.Module): def __init__(self, num_node_features, num_classes): super(GCN, self).__init__() self.conv1 = pyg_nn.GCNConv(num_node_features, 16) self.conv2 = pyg_nn.GCNConv(16, num_classes) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)

上面网络我们定义了两个GCNConv层,第一层的参数的输入维度就是初始每个节点的特征维度,输出维度是16。

第二个层的输入维度为16,输出维度为分类个数,因为我们需要对每个节点进行分类,最终加上softmax操作。

四、定义模型

下面就是定义了一些模型需要的参数,像学习率、迭代次数这些超参数,然后是模型的定义以及优化器及损失函数的定义,和pytorch定义网络是一样的。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 10 # 学习轮数lr = 0.003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 3.定义模型model = GCN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数五、模型训练

模型训练部分也是和pytorch定义网络一样,因为都是需要经过前向传播、反向传播这些过程,对于损失、精度这些指标可以自己添加。

# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')六、模型验证

下面就是模型验证阶段,在训练时我们是只使用了训练集,测试的时候我们使用的是测试集,注意这和传统网络测试不太一样,在图像分类一些经典任务中,我们是把数据集分成了两份,分别是训练集、测试集,但是在Cora这个数据集中并没有这样,它区分训练集还是测试集使用的是掩码机制,就是定义了一个和节点长度相同纬度的数组,该数组的每个位置为True或者False,标记着是否使用该节点的数据进行训练。

# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))七、结果【EPOCH: 】1训练损失为:1.9594 训练精度为:0.1571【EPOCH: 】21训练损失为:1.8681 训练精度为:0.3286【EPOCH: 】41训练损失为:1.7647 训练精度为:0.5000【EPOCH: 】61训练损失为:1.6587 训练精度为:0.5571【EPOCH: 】81训练损失为:1.5258 训练精度为:0.6714【EPOCH: 】101训练损失为:1.4334 训练精度为:0.7143【EPOCH: 】121训练损失为:1.3361 训练精度为:0.7714【EPOCH: 】141训练损失为:1.2310 训练精度为:0.8357【EPOCH: 】161训练损失为:1.1443 训练精度为:0.8571【EPOCH: 】181训练损失为:1.0962 训练精度为:0.8714【Finished Training!】>>>Train Accuracy: 0.9357 Train Loss: 0.9735>>>Test Accuracy: 0.7200 Test Loss: 1.3561训练集测试集Accuracy0.93570.7200Loss0.97351.3561完整代码import torchimport torch.nn.functional as Fimport torch.nn as nnimport torch_geometric.nn as pyg_nnfrom torch_geometric.datasets import Planetoid# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')# 2.定义GCNConv网络class GCN(nn.Module): def __init__(self, num_node_features, num_classes): super(GCN, self).__init__() self.conv1 = pyg_nn.GCNConv(num_node_features, 16) self.conv2 = pyg_nn.GCNConv(16, num_classes) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 200 # 学习轮数lr = 0.0003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 3.定义模型model = GCN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))
本文链接地址:https://www.jiuchutong.com/zhishi/297458.html 转载请保留说明!

上一篇:chrome插件开发时跨域问题解决方案(chrome插件开发语言)

下一篇:Nginx静态资源部署(nginx搭建静态资源服务器)

  • 12306网页版爱心模式是什么(ff14 跨服聊天)

    12306网页版爱心模式是什么(ff14 跨服聊天)

  • 荣耀50黑白色怎么调彩色(荣耀50黑白色怎样调回彩色)

    荣耀50黑白色怎么调彩色(荣耀50黑白色怎样调回彩色)

  • 开心消消乐怎么求助好友过关(开心消消乐怎么让好友帮忙过关)

    开心消消乐怎么求助好友过关(开心消消乐怎么让好友帮忙过关)

  • 华为nova5ipro打字时怎么关声音(华为nova5i打字声音怎么关)

    华为nova5ipro打字时怎么关声音(华为nova5i打字声音怎么关)

  • 网易云音乐账号能两个手机同时用吗(网易云音乐账号怎么找回)

    网易云音乐账号能两个手机同时用吗(网易云音乐账号怎么找回)

  • 抖音小黄车怎么开通需要什么条件(抖音小黄车怎么挂自己的产品卖货)

    抖音小黄车怎么开通需要什么条件(抖音小黄车怎么挂自己的产品卖货)

  • 为什么微信校验不通过(微信8.0 校验失败)

    为什么微信校验不通过(微信8.0 校验失败)

  • 酷喵是什么(优酷的酷喵是什么)

    酷喵是什么(优酷的酷喵是什么)

  • 魅族ba611是啥型号(魅族型号m681q)

    魅族ba611是啥型号(魅族型号m681q)

  • 计算器上的ac键是表示什么(计算器上的ac键是改错键吗)

    计算器上的ac键是表示什么(计算器上的ac键是改错键吗)

  • 手机上怎么交党费(手机上怎么交党员费)

    手机上怎么交党费(手机上怎么交党员费)

  • wps怎么删除一列的内容(wps怎么删除一列单元格的内容)

    wps怎么删除一列的内容(wps怎么删除一列单元格的内容)

  • 苹果7怎么录屏(苹果7怎么录屏功能怎么使用)

    苹果7怎么录屏(苹果7怎么录屏功能怎么使用)

  • 快手直播不能分享朋友圈怎么回事(快手不能分屏怎么办)

    快手直播不能分享朋友圈怎么回事(快手不能分屏怎么办)

  • vivo系统应用在哪(vivo手机应用系统在哪)

    vivo系统应用在哪(vivo手机应用系统在哪)

  •  电话被拉黑能发信息吗(电话被拉黑能发短信吗)

    电话被拉黑能发信息吗(电话被拉黑能发短信吗)

  • 百度hi怎么修改密码(百度怎么修改用户昵称)

    百度hi怎么修改密码(百度怎么修改用户昵称)

  • 闲鱼关闭交易会影响信誉吗(咸鱼关闭交易会怎么样)

    闲鱼关闭交易会影响信誉吗(咸鱼关闭交易会怎么样)

  • 怎么取消ca证书(取消ca证书认证)

    怎么取消ca证书(取消ca证书认证)

  • 哈利法塔湖中的迪拜喷泉,迪拜哈利法塔 (© Eli Asenova/Getty Images)(哈利法塔里面有什么)

    哈利法塔湖中的迪拜喷泉,迪拜哈利法塔 (© Eli Asenova/Getty Images)(哈利法塔里面有什么)

  • 新必应申请与使用教程:让你体验人工智能搜索引擎(新必应申请使用资格)

    新必应申请与使用教程:让你体验人工智能搜索引擎(新必应申请使用资格)

  • MSN中国男人频道采集规则For DedeCMS v5.5(中国男人百度百科)

    MSN中国男人频道采集规则For DedeCMS v5.5(中国男人百度百科)

  • python可变数据类型和不可变数据类型的区别(Python可变数据类型和不可变数据类型)

    python可变数据类型和不可变数据类型的区别(Python可变数据类型和不可变数据类型)

  • 应交税费和应交增值税
  • 个税累计免征额
  • 职工教育经费能结转几年
  • 销售材料应确认的损益是什么意思
  • 取得租金收入的会计分录
  • 有发票无明细能报销吗
  • 生产车间的房屋要交税吗
  • 施工企业已完工程成本如何结转
  • 营改增后取得施工作业收入需要交哪些税?
  • 出售无形资产属于让渡资产使用权吗
  • 个人独资的企业性质是什么
  • 营改增后附加税费入应交税费还是营业税金及附加
  • 税务改革方向
  • 自己的公司钱能自己用吗
  • 什么情况下企业不能辞退员工
  • 累计折旧属于什么
  • 没有原始凭证可以审计吗
  • 企业收入确认的依据是什么
  • PHP:mb_regex_set_options()的用法_mbstring函数
  • php实现基数排序函数
  • 个人独资企业公账转私账
  • 误餐费怎么入账
  • php的file函数
  • sgbhp.exe - sgbhp是什么进程 有什么用
  • 理财产品利息税
  • 企业租地建厂流程
  • php框架yii
  • 预收账款和应收账款的账务处理
  • 竣工结算审计费用在线计算器
  • 购买增值税税控系统如何抵扣增值税
  • java定时器怎么用
  • opencv如何显示图片
  • 应收账款减值损失计入
  • 佣金怎么收税
  • 跨年发票可以作为税前扣除的时限
  • 计划资产产生的股利
  • 什么科目需要结转到本年利润
  • 销售车位怎么找客户
  • mysql如何做优化
  • 营业税和营业税额一样吗
  • 阶段性减免社保费政策期限延长
  • 办理契税所需要的证件
  • 进项抵扣和销项抵扣
  • 冲销未开票收入还需要申报吗
  • 国有资产划转实施方案
  • 对外投资需要股东会决议吗
  • 高温补贴发放管理制度
  • 分期收款定义
  • 报关单不在海关信息中
  • 公司增资怎么办理手续
  • 最新商业会计科目做账
  • 旅行社财务会计工作内容
  • 固定资产折旧完了怎么做账
  • 收缩后对数据库有影响吗
  • sql server如何打开mdf格式文件
  • sql server分页查询sql语句
  • windows server 2008 r2离线激活
  • win7系统软件安装就闪退怎么办
  • fedora安装xorg
  • win7怎样关闭u盘保护功能
  • linux查看hz
  • 怎么使用linux命令
  • WIN10系统安装.net报错0x80072f8F
  • iptables防火墙规则
  • 用简洁的语言推荐一本书
  • python 先序遍历
  • vxlan配置实例详解
  • javascript有哪些常用的属性和方法
  • jquery prevall
  • dos批命令
  • pcs可以使用什么在任何地方以各种速率与网络保持联络
  • js的实现原理
  • jqueryif判断
  • android反编译软件
  • jquery设置图片大小
  • jqueryapi手机版
  • 在androidstudio中,如何改变图片的位置
  • 税务系统跨区调动
  • 宁波离哪个国家比较近
  • 什么是美国注册商标
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设