位置: IT常识 - 正文

Pytorch+PyG实现GraphSAGE(pytorch with no grad)

编辑:rootadmin
Pytorch+PyG实现GraphSAGE 文章目录前言一、导入相关库二、加载Cora数据集三、定义GraphSAGE网络四、定义模型五、模型训练六、模型验证七、结果完整代码前言

推荐整理分享Pytorch+PyG实现GraphSAGE(pytorch with no grad),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch autograd.grad,pytorch grad,pytorch autograd.grad,pytorch gym,pytorch grad,pytorch autograd.grad,pytorch grad,pytorch vgg,内容如对您有帮助,希望把文章链接给更多的朋友!

大家好,我是阿光。

本专栏整理了《图神经网络代码实战》,内包含了不同图神经网络的相关代码实现(PyG以及自实现),理论与实践相结合,如GCN、GAT、GraphSAGE等经典图网络,每一个代码实例都附带有完整的代码。

正在更新中~ ✨

🚨 我的项目环境:

平台:Windows10语言环境:python3.7编译器:PyCharmPyTorch版本:1.11.0PyG版本:2.1.0

💥 项目专栏:【图神经网络代码实战目录】

本文我们将使用Pytorch + Pytorch Geometric来简易实现一个GraphSAGE,让新手可以理解如何PyG来搭建一个简易的图网络实例demo。

一、导入相关库Pytorch+PyG实现GraphSAGE(pytorch with no grad)

本项目我们需要结合两个库,一个是Pytorch,因为还需要按照torch的网络搭建模型进行书写,第二个是PyG,因为在torch中并没有关于图网络层的定义,所以需要torch_geometric这个库来定义一些图层。

import torchimport torch.nn.functional as Fimport torch.nn as nnimport torch_geometric.nn as pyg_nnfrom torch_geometric.datasets import Planetoid二、加载Cora数据集

本文使用的数据集是比较经典的Cora数据集,它是一个根据科学论文之间相互引用关系而构建的Graph数据集合,论文分为7类,共2708篇。

Genetic_AlgorithmsNeural_NetworksProbabilistic_MethodsReinforcement_LearningRule_LearningTheory

这个数据集是一个用于图节点分类的任务,数据集中只有一张图,这张图中含有2708个节点,10556条边,每个节点的特征维度为1433。

# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')三、定义GraphSAGE网络

这里我们就不重点介绍GraphSAGE网络了,相信大家能够掌握基本原理,本文我们使用的是PyG定义网络层,在PyG中已经定义好了SAGEConv这个层,该层采用的就是GraphSAGE机制。

对于SAGEConv的常用参数:

in_channels:每个样本的输入维度,就是每个节点的特征维度out_channels:经过注意力机制后映射成的新的维度,就是经过GAT后每个节点的维度长度normalize:是否添加自环,并且是否归一化,默认为Trueadd_self_loops:为图添加自环,是否考虑自身节点的信息bias:训练一个偏置b# 2.定义GraphSAGE网络class GraphSAGE(nn.Module): def __init__(self, num_node_features, num_classes): super(GraphSAGE, self).__init__() self.conv1 = pyg_nn.SAGEConv(num_node_features, 16) self.conv2 = pyg_nn.SAGEConv(16, num_classes) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)

上面网络我们定义了两个SAGEConv层,第一层的参数的输入维度就是初始每个节点的特征维度,输出维度是16。

第二个层的输入维度为16,输出维度为分类个数,因为我们需要对每个节点进行分类,最终加上softmax操作。

四、定义模型

下面就是定义了一些模型需要的参数,像学习率、迭代次数这些超参数,然后是模型的定义以及优化器及损失函数的定义,和pytorch定义网络是一样的。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 200 # 学习轮数lr = 0.0003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 3.定义模型model = GraphSAGE(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数五、模型训练

模型训练部分也是和pytorch定义网络一样,因为都是需要经过前向传播、反向传播这些过程,对于损失、精度这些指标可以自己添加。

# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')六、模型验证

下面就是模型验证阶段,在训练时我们是只使用了训练集,测试的时候我们使用的是测试集,注意这和传统网络测试不太一样,在图像分类一些经典任务中,我们是把数据集分成了两份,分别是训练集、测试集,但是在Cora这个数据集中并没有这样,它区分训练集还是测试集使用的是掩码机制,就是定义了一个和节点长度相同纬度的数组,该数组的每个位置为True或者False,标记着是否使用该节点的数据进行训练。

# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))七、结果【EPOCH: 】1训练损失为:1.9547 训练精度为:0.1429【EPOCH: 】21训练损失为:1.8378 训练精度为:0.2143【EPOCH: 】41训练损失为:1.6961 训练精度为:0.3929【EPOCH: 】61训练损失为:1.4987 训练精度为:0.6857【EPOCH: 】81训练损失为:1.3121 训练精度为:0.7714【EPOCH: 】101训练损失为:1.1580 训练精度为:0.9143【EPOCH: 】121训练损失为:0.9903 训练精度为:0.8643【EPOCH: 】141训练损失为:0.8326 训练精度为:0.9286【EPOCH: 】161训练损失为:0.7429 训练精度为:0.9571【EPOCH: 】181训练损失为:0.6505 训练精度为:0.9571【Finished Training!】>>>Train Accuracy: 1.0000 Train Loss: 0.4065>>>Test Accuracy: 0.7060 Test Loss: 1.2712训练集测试集Accuracy1.00000.7060Loss0.40651.2712完整代码import torchimport torch.nn.functional as Fimport torch.nn as nnimport torch_geometric.nn as pyg_nnfrom torch_geometric.datasets import Planetoid# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')# 2.定义GraphSAGE网络class GraphSAGE(nn.Module): def __init__(self, num_node_features, num_classes): super(GraphSAGE, self).__init__() self.conv1 = pyg_nn.SAGEConv(num_node_features, 16) self.conv2 = pyg_nn.SAGEConv(16, num_classes) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 200 # 学习轮数lr = 0.0003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 3.定义模型model = GraphSAGE(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))
本文链接地址:https://www.jiuchutong.com/zhishi/298676.html 转载请保留说明!

上一篇:GPT-4重磅发布,它究竟厉害在哪?(gpt40)

下一篇:Typescript 全栈最值得学习的技术栈 TRPC(typescript完全解读)

  • 苹果13mini怎么连拍(苹果13mini怎么连接耳机)

    苹果13mini怎么连拍(苹果13mini怎么连接耳机)

  • 千牛不支持苹果系统吗(苹果手机怎么装不了千牛)

    千牛不支持苹果系统吗(苹果手机怎么装不了千牛)

  • 华为nova7可以隐藏应用吗(华为nova7可以隐藏桌面图标吗)

    华为nova7可以隐藏应用吗(华为nova7可以隐藏桌面图标吗)

  • 电脑连校园网为什么不弹出登录页面(电脑连校园网为什么打不开网页)

    电脑连校园网为什么不弹出登录页面(电脑连校园网为什么打不开网页)

  • 发出去的微信删除对方还能看见吗(发出的微信删除后对方还看得见吗)

    发出去的微信删除对方还能看见吗(发出的微信删除后对方还看得见吗)

  • 苹果7plus为什么发烫(苹果7plus为什么会频繁出现自动开关机)

    苹果7plus为什么发烫(苹果7plus为什么会频繁出现自动开关机)

  • 手机开空调是不是要下个什么软件(手机开空调不能调温度)

    手机开空调是不是要下个什么软件(手机开空调不能调温度)

  • 苹果id都能查出什么(苹果id能查什么)

    苹果id都能查出什么(苹果id能查什么)

  • 加湿器有风无雾是什么原因(加湿器有风无雾解析图片)

    加湿器有风无雾是什么原因(加湿器有风无雾解析图片)

  • 如何改华为手机锁屏时间(如何改华为手机下方的按键)

    如何改华为手机锁屏时间(如何改华为手机下方的按键)

  • 照片后缀aee是什么意思(图片后缀aae)

    照片后缀aee是什么意思(图片后缀aae)

  • 手机用的时间长了反应很慢怎么办(手机用的时间长了卡顿怎么办)

    手机用的时间长了反应很慢怎么办(手机用的时间长了卡顿怎么办)

  • 计算机中的媒体是什么(计算机中的媒体分为哪几类)

    计算机中的媒体是什么(计算机中的媒体分为哪几类)

  • 计算机硬件系统包括哪些(计算机硬件系统由哪几个部分组成)

    计算机硬件系统包括哪些(计算机硬件系统由哪几个部分组成)

  • 手机怎么设置字体上有拼音(手机怎么设置字体上面有拼音)

    手机怎么设置字体上有拼音(手机怎么设置字体上面有拼音)

  • 一加7T Pro怎么显示网速(一加7t pro怎样)

    一加7T Pro怎么显示网速(一加7t pro怎样)

  • vivo如何查看后台运行

    vivo如何查看后台运行

  • 中国联通hd什么意思(中国联通hd有什么用)

    中国联通hd什么意思(中国联通hd有什么用)

  • 苹果11有512G吗(iphone 11有没有512g的)

    苹果11有512G吗(iphone 11有没有512g的)

  • 公众平台修改登录邮箱方法(公众号修改登录密码怎么修改)

    公众平台修改登录邮箱方法(公众号修改登录密码怎么修改)

  • directx.exe是病毒程序吗 directx进程安全吗(directx安全吗)

    directx.exe是病毒程序吗 directx进程安全吗(directx安全吗)

  • python机器人编程——差速机器人小车的控制,控制模型、轨迹跟踪,轨迹规划、自动泊车(上)(python机器人编程控制)

    python机器人编程——差速机器人小车的控制,控制模型、轨迹跟踪,轨迹规划、自动泊车(上)(python机器人编程控制)

  • bootadm命令  管理引导配置(bootz命令)

    bootadm命令 管理引导配置(bootz命令)

  • 纳税会计的要素有
  • 营业税金及附加是什么科目
  • 稳岗返还多久能到账
  • 税控盘维护费280多久可以抵扣
  • 8.会计核算方法具体包括哪些内容?
  • 未签购销合同需不需要印花税
  • 其他流动资产对应科目
  • 社保补贴有几年
  • 已导出的申报表如何修改
  • t3用友软件怎么设置三级科目
  • 社保基数与工资不符
  • 监理费可以由施工方出吗
  • 长期挂账应收账款怎么调
  • 手工帐怎么登记
  • 预付账款转入其他非流动资产
  • 以前月度费用当期怎么入账合适?
  • 产生的信息服务有哪些
  • 履约保证金需纳什么税
  • 行政机关作出下列行为属于行政复议的范围
  • 生产企业出口退税流程怎么操作
  • 应交增值税进项税额转出借贷方向表示什么
  • 签合同交什么照片比较好
  • 案例分析两个分公司转资金怎么做账?
  • 财务物料消耗都有哪些
  • 专项维修基金和契税有什么区别
  • 普票的销项可以抵扣吗?
  • 投资利税率计算器在线计算
  • 费用报销单的日期
  • 现金支票工本费发票
  • 公司注销要交分红税吗
  • 收到苗木发票怎么做账
  • 小规模纳税人确认收入时要确认税吗
  • 有限公司股权怎么划分
  • mac开机按command+r没反应
  • 随机赠送是啥意思
  • 公司废业
  • wordpress相关文章插件
  • php 设计模式
  • 广告费发票内容是什么
  • 短视频小程序源码
  • kindeditor编辑器图片上传
  • antd:ConfigProvider+getPopupContainer解决筛选框遮挡问题(及其他浮层问题)
  • dpkg命令详解
  • 其他权益工具投资是金融资产吗
  • 本期到期债务计算公式
  • 不动产投资应该怎么做账
  • Python解释器有哪些种类
  • 加油的电子发票在哪里找
  • 旅行社小规模纳税人差额征税
  • 企业选择简易征收方案
  • 差旅费属于什么支出类型
  • 宾馆收入怎么做账
  • 未开票的收入如何确认分录
  • 汇算清缴退税分录
  • 增值税发票未认证丢失怎么办
  • 代持的股份
  • 弥补企业以前年度亏损 顺序
  • 加计扣除所得税申报表怎么填写
  • 加油卡充值发票可以抵税吗
  • 计提资产减值准备会计科目
  • 工业企业的材料
  • windows update更新卡住不动了
  • windows server 2016最大内存
  • freebsd怎么安装
  • win7系统运行慢,如何提速
  • 电脑键盘上f1到f12快捷键的功能分别是
  • windows7怎么卸载
  • win10系统如何关闭杀毒软件和防火墙
  • win1021h2版本怎么样
  • three.js gui
  • ie版本过低怎么升级win7
  • 批处理编程教程
  • 批处理的扩展名
  • javascript还有人用吗
  • js正则匹配特殊符号
  • JavaSacript中charCodeAt()方法的使用详解
  • 用js实现类的方法
  • scrollbottom用法
  • 银行扣账户维护费会计分录
  • 商贸有限公司怎么运营
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设