位置: IT常识 - 正文

Pytorch+PyG实现GCN(图卷积网络)(pytorch go)

编辑:rootadmin
Pytorch+PyG实现GCN(图卷积网络) 文章目录前言一、导入相关库二、加载Cora数据集三、定义GCN网络四、定义模型五、模型训练六、模型验证七、结果完整代码前言

推荐整理分享Pytorch+PyG实现GCN(图卷积网络)(pytorch go),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:python gcc,pytorch no_grad,pytorch no_grad,pytorch的gru,pytorch gym,pytorch gym,pytorch vgg,pytorch vgg,内容如对您有帮助,希望把文章链接给更多的朋友!

大家好,我是阿光。

本专栏整理了《图神经网络代码实战》,内包含了不同图神经网络的相关代码实现(PyG以及自实现),理论与实践相结合,如GCN、GAT、GraphSAGE等经典图网络,每一个代码实例都附带有完整的代码。

正在更新中~ ✨

🚨 我的项目环境:

平台:Windows10语言环境:python3.7编译器:PyCharmPyTorch版本:1.11.0PyG版本:2.1.0

💥 项目专栏:【图神经网络代码实战目录】

Pytorch+PyG实现GCN(图卷积网络)(pytorch go)

本文我们将使用Pytorch + Pytorch Geometric来简易实现一个GCN(图卷积网络),让新手可以理解如何PyG来搭建一个简易的图网络实例demo。

一、导入相关库

本项目我们需要结合两个库,一个是Pytorch,因为还需要按照torch的网络搭建模型进行书写,第二个是PyG,因为在torch中并没有关于图网络层的定义,所以需要torch_geometric这个库来定义一些图层。

import torchimport torch.nn.functional as Fimport torch.nn as nnimport torch_geometric.nn as pyg_nnfrom torch_geometric.datasets import Planetoid二、加载Cora数据集

本文使用的数据集是比较经典的Cora数据集,它是一个根据科学论文之间相互引用关系而构建的Graph数据集合,论文分为7类,共2708篇。

Genetic_AlgorithmsNeural_NetworksProbabilistic_MethodsReinforcement_LearningRule_LearningTheory

这个数据集是一个用于图节点分类的任务,数据集中只有一张图,这张图中含有2708个节点,10556条边,每个节点的特征维度为1433。

# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')三、定义GCN网络

这里我们就不重点介绍GCN网络了,相信大家能够掌握基本原理,本文我们使用的是PyG定义网络层,在PyG中已经定义好了GCNConv这个层,该层采用的就是GCN机制。

对于GCNConv的常用参数:

in_channels:每个样本的输入维度,就是每个节点的特征维度out_channels:经过注意力机制后映射成的新的维度,就是经过GAT后每个节点的维度长度normalize:是否添加自环,并且是否归一化,默认为Trueadd_self_loops:为图添加自环,是否考虑自身节点的信息bias:训练一个偏置b# 2.定义GCNConv网络class GCN(nn.Module): def __init__(self, num_node_features, num_classes): super(GCN, self).__init__() self.conv1 = pyg_nn.GCNConv(num_node_features, 16) self.conv2 = pyg_nn.GCNConv(16, num_classes) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)

上面网络我们定义了两个GCNConv层,第一层的参数的输入维度就是初始每个节点的特征维度,输出维度是16。

第二个层的输入维度为16,输出维度为分类个数,因为我们需要对每个节点进行分类,最终加上softmax操作。

四、定义模型

下面就是定义了一些模型需要的参数,像学习率、迭代次数这些超参数,然后是模型的定义以及优化器及损失函数的定义,和pytorch定义网络是一样的。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 10 # 学习轮数lr = 0.003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 3.定义模型model = GCN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数五、模型训练

模型训练部分也是和pytorch定义网络一样,因为都是需要经过前向传播、反向传播这些过程,对于损失、精度这些指标可以自己添加。

# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')六、模型验证

下面就是模型验证阶段,在训练时我们是只使用了训练集,测试的时候我们使用的是测试集,注意这和传统网络测试不太一样,在图像分类一些经典任务中,我们是把数据集分成了两份,分别是训练集、测试集,但是在Cora这个数据集中并没有这样,它区分训练集还是测试集使用的是掩码机制,就是定义了一个和节点长度相同纬度的数组,该数组的每个位置为True或者False,标记着是否使用该节点的数据进行训练。

# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))七、结果【EPOCH: 】1训练损失为:1.9594 训练精度为:0.1571【EPOCH: 】21训练损失为:1.8681 训练精度为:0.3286【EPOCH: 】41训练损失为:1.7647 训练精度为:0.5000【EPOCH: 】61训练损失为:1.6587 训练精度为:0.5571【EPOCH: 】81训练损失为:1.5258 训练精度为:0.6714【EPOCH: 】101训练损失为:1.4334 训练精度为:0.7143【EPOCH: 】121训练损失为:1.3361 训练精度为:0.7714【EPOCH: 】141训练损失为:1.2310 训练精度为:0.8357【EPOCH: 】161训练损失为:1.1443 训练精度为:0.8571【EPOCH: 】181训练损失为:1.0962 训练精度为:0.8714【Finished Training!】>>>Train Accuracy: 0.9357 Train Loss: 0.9735>>>Test Accuracy: 0.7200 Test Loss: 1.3561训练集测试集Accuracy0.93570.7200Loss0.97351.3561完整代码import torchimport torch.nn.functional as Fimport torch.nn as nnimport torch_geometric.nn as pyg_nnfrom torch_geometric.datasets import Planetoid# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')# 2.定义GCNConv网络class GCN(nn.Module): def __init__(self, num_node_features, num_classes): super(GCN, self).__init__() self.conv1 = pyg_nn.GCNConv(num_node_features, 16) self.conv2 = pyg_nn.GCNConv(16, num_classes) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 200 # 学习轮数lr = 0.0003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 3.定义模型model = GCN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))
本文链接地址:https://www.jiuchutong.com/zhishi/297458.html 转载请保留说明!

上一篇:chrome插件开发时跨域问题解决方案(chrome插件开发语言)

下一篇:Nginx静态资源部署(nginx搭建静态资源服务器)

  • 央视频怎么看回看(央视频怎么看回放卫视)

    央视频怎么看回看(央视频怎么看回放卫视)

  • 蓝牙耳机如何充电(蓝牙耳机如何充电显示)

    蓝牙耳机如何充电(蓝牙耳机如何充电显示)

  • ipx4级防水是指什么(ip4x防水等级是什么)

    ipx4级防水是指什么(ip4x防水等级是什么)

  • word中分页符怎么弄(word中分页符怎么显示)

    word中分页符怎么弄(word中分页符怎么显示)

  • 微软powerbi免费吗(powerbi.microsoft.com)

    微软powerbi免费吗(powerbi.microsoft.com)

  • 主板上的sata1234有区别吗(主板上的satapower接口)

    主板上的sata1234有区别吗(主板上的satapower接口)

  • 微信创建群聊没发消息别人知道吗(创建的微信群聊不见了怎么办)

    微信创建群聊没发消息别人知道吗(创建的微信群聊不见了怎么办)

  • 两块固态硬盘怎么一起用(两块固态硬盘怎么设置主从盘)

    两块固态硬盘怎么一起用(两块固态硬盘怎么设置主从盘)

  • 企鹅电竞为什么黑屏(企鹅电竞为什么会停运)

    企鹅电竞为什么黑屏(企鹅电竞为什么会停运)

  • powerpoint的主要功能是(powerpoint的主要应用)

    powerpoint的主要功能是(powerpoint的主要应用)

  • 惠普403d硒鼓型号(惠普m403硒鼓)

    惠普403d硒鼓型号(惠普m403硒鼓)

  • 快手点红心怎么收费(快手点红心怎么批量取消)

    快手点红心怎么收费(快手点红心怎么批量取消)

  • ps怎么把人p白(用ps把人变白)

    ps怎么把人p白(用ps把人变白)

  • 手机耳机插电脑上能说话吗(手机耳机插电脑声音小怎么解决)

    手机耳机插电脑上能说话吗(手机耳机插电脑声音小怎么解决)

  • 快手怎么搜索视频(快手怎么搜索视频主人)

    快手怎么搜索视频(快手怎么搜索视频主人)

  • 闲鱼卖家胜利多久到款(闲鱼卖家胜利钱什么时候到账)

    闲鱼卖家胜利多久到款(闲鱼卖家胜利钱什么时候到账)

  • 网络卡顿怎么处理(网络卡顿怎么处理无线)

    网络卡顿怎么处理(网络卡顿怎么处理无线)

  • 苹果x滚动截屏怎么用(苹果滚动截屏怎么弄的)

    苹果x滚动截屏怎么用(苹果滚动截屏怎么弄的)

  • realmex是啥手机(realmex手机怎么样值得买吗)

    realmex是啥手机(realmex手机怎么样值得买吗)

  • iphonex接电话没声音(iphonex接电话没声音,打电话可以听到声音)

    iphonex接电话没声音(iphonex接电话没声音,打电话可以听到声音)

  • 快手有赞订单怎么查询(快手赞订单怎么找)

    快手有赞订单怎么查询(快手赞订单怎么找)

  • 液晶显示器有辐射吗(液晶显示器辐射大不大)

    液晶显示器有辐射吗(液晶显示器辐射大不大)

  • 查找我的iphone怎么关闭 查找我的iphone强制关闭方(查找我的iphone怎么添加设备)

    查找我的iphone怎么关闭 查找我的iphone强制关闭方(查找我的iphone怎么添加设备)

  • 详解Linux系统中字符串搜索命令ngrep的用法(linux的sh)

    详解Linux系统中字符串搜索命令ngrep的用法(linux的sh)

  • 不动产用于集体福利能否抵扣
  • 离职补偿金怎么做账
  • 劳务派遣增值税怎么算
  • 货币资金项目应根据账户的期末余额合计填列
  • 房地产企业收到预收款如何纳税
  • 房屋租赁需要交税吗?
  • 房屋估价入账需要计算什么税款?
  • 典当行借贷属于民间借贷吗
  • 负数发票跨月怎么重开
  • 物业公司支付出的费用
  • 不动产发票怎么填写
  • 个人所得税0申报逾期
  • 预缴所得税如何做账
  • 企业法人和股份的关系
  • 小规模纳税人年度不超过500万
  • 企业所得税应纳税所得额包括什么
  • 交叉持股的合并财务报表
  • 亏损的递延所得税怎么理解
  • 交易性金融资产的账务处理
  • 存货折扣怎样做账
  • 农业合作社出售农产品怎么计税
  • 零售业收入
  • 怎么光驱重装系统
  • linux命令“ln file1 file2”的含义是
  • macbookpro常见问题
  • 退回多缴的所得税怎么算
  • 新会计准则计入管理费用的税费
  • 基础知识讲解
  • 资产减值对应科目
  • github ci/cd
  • thinkphp 分页
  • df -th命令
  • 折扣返利的账务处理
  • 织梦内容页模板修改
  • 公司购买空调计入什么费用
  • 印花税的税率变动
  • 是不是所有的发票都是一样的
  • 怎么找回丢失的华为手机
  • 预缴增值税的账务处理
  • 主营业务成本账户属于什么账户
  • 房地产公司工程部岗位职责
  • 应收票据背书转让分录
  • 预付账款主要是什么
  • 残保金工资总额是按计提还是发放
  • 242104 税控盘
  • 金税三期网络设置
  • 数据库连接说明
  • sql判断字符串是否为日期
  • mac mysql密码
  • win7 bug
  • 正版vista一键升级win7
  • mac app store打开一片空白
  • 在Linux系统中安装虚拟window
  • Linux配置防火墙端口
  • 检测你的vps是不是真的
  • windows10磁盘
  • 如何打开音量控制器
  • mac电池不能被识别吗
  • myfastupdate.exe - myfastupdate是什么进程文件 有什么用
  • win8.1安装过程
  • win7网速很慢
  • kb4592449-windows安全每月质量汇总
  • TestOpenGL
  • 快速掌握日语词汇
  • linux服务器硬件配置要求
  • django框架mvt
  • perl脚本调试方法
  • linux命令scp和sftp详细介绍
  • cmd的tree指令
  • python 技巧总结
  • javascript例题
  • JavaScript indexOf方法入门实例(计算指定字符在字符串中首次出现的位置)
  • 安卓activity和fragment的区别
  • 15个值得开发人是谁
  • 开票软件如何升级系统
  • 农产品核定扣除办法38号公告
  • 房产税纳税义务终止
  • 湖南国家电子税务局企业所得税申报进不去
  • 临沭公交车多久一班
  • 苏州封闭式高中
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设