位置: IT常识 - 正文

Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding)

编辑:rootadmin
Pytorch实现EdgeCNN(基于PyTorch实现) 文章目录前言一、导入相关库二、加载Cora数据集三、定义EdgeCNN网络3.1 定义EdgeConv层3.1.1 特征拼接3.1.2 max聚合3.1.3 特征映射3.1.4 EdgeConv层3.2 定义EdgeCNN网络四、定义模型五、模型训练六、模型验证七、结果完整代码前言

推荐整理分享Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch encoder decoder,pytorch embedding lookup,pytorch embedding lookup,pyTorch实现多分类预测,pytorch demo,pytorch encoder decoder,pyTorch实现多分类预测,pytorch generator,内容如对您有帮助,希望把文章链接给更多的朋友!

大家好,我是阿光。

本专栏整理了《图神经网络代码实战》,内包含了不同图神经网络的相关代码实现(PyG以及自实现),理论与实践相结合,如GCN、GAT、GraphSAGE等经典图网络,每一个代码实例都附带有完整的代码。

正在更新中~ ✨

🚨 我的项目环境:

平台:Windows10语言环境:python3.7编译器:PyCharmPyTorch版本:1.11.0PyG版本:2.1.0

💥 项目专栏:【图神经网络代码实战目录】

本文我们将使用PyTorch来简易实现一个EdgeCNN,不使用PyG库,让新手可以理解如何PyTorch来搭建一个简易的图网络实例demo。

一、导入相关库

本项目是采用自己实现的EdgeCNN,并没有使用 PyG 库,原因是为了帮助新手朋友们能够对EdgeConv的原理有个更深刻的理解,如果熟悉之后可以尝试使用PyG库直接调用 EdgeConv 这个图层即可。

import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch_geometric.utils import scatterfrom torch_geometric.datasets import Planetoid二、加载Cora数据集

本文使用的数据集是比较经典的Cora数据集,它是一个根据科学论文之间相互引用关系而构建的Graph数据集合,论文分为7类,共2708篇。

Genetic_AlgorithmsNeural_NetworksProbabilistic_MethodsReinforcement_LearningRule_LearningTheory

这个数据集是一个用于图节点分类的任务,数据集中只有一张图,这张图中含有2708个节点,10556条边,每个节点的特征维度为1433。

# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')三、定义EdgeCNN网络3.1 定义EdgeConv层

这里我们就不重点介绍EdgeCNN网络了,相信大家能够掌握基本原理,本文我们使用的是PyTorch定义网络层。

对于EdgeConv的常用参数:

nn:进行节点特征转换使用的 MLP网络,需要自己定义传入aggr:聚合邻居节点特征时采用的方式,默认为 max

我们在实现时也是考虑这几个常见参数

对于EdgeConv的传播公式为: xi′=∑j∈N(i)hθ(xi∣∣xj−xi)x_i'=\sum_{j\in N(i)}h_{\theta}(x_i||x_j-x_i)xi′​=j∈N(i)∑​hθ​(xi​∣∣xj​−xi​)

上式子中的 xix_ixi​ 代表中心节点特征信息, xjx_jxj​ 代表邻居节点的特征信息,对于 hθh_{\theta}hθ​ 代表每个 EdgeConv 层的可学习参数,也就是对应传入的MLP层中的可学习参数。

Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding)

所以我们的任务无非就是获取这几个变量,然后进行传播计算即可

3.1.1 特征拼接

该环节实现的公式为:xi∣∣xj−xix_i||x_j-x_ixi​∣∣xj​−xi​,对于这个公式来说,我们要获得两个变量,一个是中心节点 xix_ixi​(target)的特征信息,一个是邻居节点 xjx_jxj​(source)的特征信息。

对于这两个变量的获取很容易,利用 edge_index 就可以提取出来,edge_index 中保存的是每一条边的一对起始节点与终止节点,对于起始节点可以认为就是 i,对于终止节点就可以认为是 j,然后我们就会获得两个向量,分别为 row 和 col ,这两个向量就是起始顶点和终止顶点的集合。

然后我们在根据索引进行提取特征,利用 x_i = x[row] 和 x_j = x[col] 就可以将中心节点和终止节点对应的特征获取,维度为【E,feature_size】。

然后就可以按照公式实现做差然后与中心节点的特征进行拼接,获得拼接后的特征维度为原来的2倍。

row, col = edge_index # 获取target、source节点索引 [E]x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size]x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size]x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size]

对于这里 x_i 和 x_j 以及起始节点的索引初学可能混淆,所以多多打印中间结果一步一步调试进行理解。

3.1.2 max聚合

对于 EdgeConv 的默认聚合方式为 max,其实还可以使用 mean 、sum 等排列不变函数进行聚合。

对于聚合操作就是公式中求和符号那里,只不过框架给的公式是 sum ,对于聚合我们希望做的是将中心节点的邻居特征按照指定的聚合方式进行聚合。

我们可以利用 PyG 工具库中提供的 scatter 函数进行操作,该函数可以指定聚合方式以及聚合维度等参数,使用方法就是需要传入需要聚合的 Tensor ,此外还需要传入一个 index ,指明哪些向量为同一个邻居的节点,举个例子,我们传入的 index=[0,0,0,1,1] ,这就代表第一个、第二个、第三个为同一节点的邻居,所以就会将待聚合的 Tensor 的第一个向量、第二个向量、第三个向量按照指定聚合方式进行聚合。

这里说的有点抽象,自己尝试一个简单示例就明白了。

out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size]3.1.3 特征映射

在公式中有个 hθh_{\theta}hθ​,这个就代表 MLP 做特征映射做的,对于官方给的 EdgeConv 需要我们手动传入 MLP 模型,所以本项目自实现也是按照这种方式,MLP 的操作在 EdgeConv 中并没有实现,而是利用传入的模型进行操作。

这里注意一点就是定义的 MLP 模型的输入维度应该为原始维度的2倍,因为我们在这之前进行了特征拼接操作,所以特征维度进行了加倍。

out = self.mlp(out) # 特征映射 [num_nodes, out_channels]3.1.4 EdgeConv层

接下来就可以定义EdgeConv层了,该层实现了1个函数,为 forward()

forward():这个函数定义模型的传播过程,也就是上面公式的 xi′=∑j∈N(i)hθ(xi∣∣xj−xi)x_i'=\sum_{j\in N(i)}h_{\theta}(x_i||x_j-x_i)xi′​=∑j∈N(i)​hθ​(xi​∣∣xj​−xi​)# 2.定义EdgeConv层class EdgeConv(nn.Module): def __init__(self, nn, aggr='max'): super(EdgeConv, self).__init__() self.mlp = nn # MLP网络 def forward(self, x, edge_index): row, col = edge_index # 获取target、source节点索引 [E] x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size] x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size] x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size] out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size] out = self.mlp(out) # 特征映射 [num_nodes, out_channels] return out

对于我们实现这个网络的实现效率上来讲比PyG框架内置的 EdgeConv 层稍差一点,因为我们是按照公式来一步一步利用矩阵计算得到,没有对矩阵计算以及算法进行优化,不然初学者可能看不太懂,不利于理解EdgeConv公式的传播过程,有能力的小伙伴可以看下官方源码学习一下,框架内是按照消息传递方式实现的。

3.2 定义EdgeCNN网络

上面我们已经实现好了 EdgeConv 的网络层,之后就可以调用这个层来搭建 EdgeCNN 网络。

# 3.定义EdgeConv网络class EdgeCNN(nn.Module): def __init__(self, num_node_features, num_classes): super(EdgeCNN, self).__init__() self.conv1 = EdgeConv(nn=nn.Linear(2 * num_node_features, 16), aggr='max') self.conv2 = EdgeConv(nn=nn.Linear(2 * 16, num_classes), aggr='max') def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)

上面网络我们定义了两个EdgeConv层,第一层的参数的输入维度就是初始每个节点的特征维度 * 2,输出维度是16。

第二个层的输入维度为16 * 2,输出维度为分类个数,因为我们需要对每个节点进行分类,最终加上softmax操作。

这里说明一下为什么要将输入乘以2,原因是在使用MLP进行特征转换之前,会将中心节点的特征与中心节点和邻居节点的差向量做拼接,所以得到的输出维度为节点的特征维度 * 2。

四、定义模型

下面就是定义了一些模型需要的参数,像学习率、迭代次数这些超参数,然后是模型的定义以及优化器及损失函数的定义,和pytorch定义网络是一样的。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 10 # 学习轮数lr = 0.003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 3.定义模型model = EdgeCNN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数五、模型训练

模型训练部分也是和pytorch定义网络一样,因为都是需要经过前向传播、反向传播这些过程,对于损失、精度这些指标可以自己添加。

# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')六、模型验证

下面就是模型验证阶段,在训练时我们是只使用了训练集,测试的时候我们使用的是测试集,注意这和传统网络测试不太一样,在图像分类一些经典任务中,我们是把数据集分成了两份,分别是训练集、测试集,但是在Cora这个数据集中并没有这样,它区分训练集还是测试集使用的是掩码机制,就是定义了一个和节点长度相同纬度的数组,该数组的每个位置为True或者False,标记着是否使用该节点的数据进行训练。

# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))七、结果【EPOCH: 】1训练损失为:1.9629 训练精度为:0.1214【EPOCH: 】21训练损失为:1.6709 训练精度为:0.5714【EPOCH: 】41训练损失为:1.3965 训练精度为:0.7571【EPOCH: 】61训练损失为:1.1095 训练精度为:0.8643【EPOCH: 】81训练损失为:0.9088 训练精度为:0.9286【EPOCH: 】101训练损失为:0.7454 训练精度为:0.9643【EPOCH: 】121训练损失为:0.5841 训练精度为:0.9643【EPOCH: 】141训练损失为:0.4985 训练精度为:0.9714【EPOCH: 】161训练损失为:0.3954 训练精度为:0.9714【EPOCH: 】181训练损失为:0.3339 训练精度为:0.9857【Finished Training!】>>>Train Accuracy: 1.0000 Train Loss: 0.3133>>>Test Accuracy: 0.4230 Test Loss: 1.6562训练集测试集Accuracy1.00000.4230Loss0.31331.6562完整代码import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch_geometric.utils import scatterfrom torch_geometric.datasets import Planetoid# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')# 2.定义EdgeConv层class EdgeConv(nn.Module): def __init__(self, nn, aggr='max'): super(EdgeConv, self).__init__() self.mlp = nn # MLP网络 def forward(self, x, edge_index): row, col = edge_index # 获取target、source节点索引 [E] x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size] x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size] x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size] out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size] out = self.mlp(out) # 特征映射 [num_nodes, out_channels] return out# 3.定义EdgeConv网络class EdgeCNN(nn.Module): def __init__(self, num_node_features, num_classes): super(EdgeCNN, self).__init__() self.conv1 = EdgeConv(nn=nn.Linear(2 * num_node_features, 16), aggr='max') self.conv2 = EdgeConv(nn=nn.Linear(2 * 16, num_classes), aggr='max') def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 200 # 学习轮数lr = 0.0003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 4.定义模型model = EdgeCNN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))
本文链接地址:https://www.jiuchutong.com/zhishi/300575.html 转载请保留说明!

上一篇:用css画一个csdn程序猿(用css画一个扇形)

下一篇:网易二面:CPU狂飙900%,该怎么处理?(网易游戏二面)

  • 业主不交物业费需要需要承担的责任是什么

    业主不交物业费需要需要承担的责任是什么

  • 华为手机重置后怎么恢复数据(华为手机重置后还能恢复数据吗)

    华为手机重置后怎么恢复数据(华为手机重置后还能恢复数据吗)

  • 怎么关闭电脑自动更新win10(怎么关闭电脑自动更新系统)

    怎么关闭电脑自动更新win10(怎么关闭电脑自动更新系统)

  • 苹果手机健康码快捷指令怎么设置(苹果手机上的健康码怎么弄)

    苹果手机健康码快捷指令怎么设置(苹果手机上的健康码怎么弄)

  • windows10重置此电脑(Windows10重置此电脑C盘内容会清空吗)

    windows10重置此电脑(Windows10重置此电脑C盘内容会清空吗)

  • 滴滴抢不到单是怎么回事(滴滴快车抢不到单)

    滴滴抢不到单是怎么回事(滴滴快车抢不到单)

  • 打印机能传真吗(打印机也有传真功能吗)

    打印机能传真吗(打印机也有传真功能吗)

  • 荣耀8x录屏在哪(荣耀8x录屏在哪里打开)

    荣耀8x录屏在哪(荣耀8x录屏在哪里打开)

  • 红外摄像头原理

    红外摄像头原理

  • 为什么家里wifi别人都能用,自己的用不了(为什么家里wifi突然显示不可使用)

    为什么家里wifi别人都能用,自己的用不了(为什么家里wifi突然显示不可使用)

  • 微信视频没有声音是哪里关掉了(微信视频没有声音什么原因,如何修复)

    微信视频没有声音是哪里关掉了(微信视频没有声音什么原因,如何修复)

  • 操作系统管理的计算机系统资源包括(操作系统管理的软硬件资源有哪些)

    操作系统管理的计算机系统资源包括(操作系统管理的软硬件资源有哪些)

  • 苹果手机升级系统的优缺点(苹果手机升级系统后怎么恢复旧系统)

    苹果手机升级系统的优缺点(苹果手机升级系统后怎么恢复旧系统)

  • 微信钱包在哪(企业微信钱包在哪)

    微信钱包在哪(企业微信钱包在哪)

  • word文件打不开怎么办(word文档打不开显示内容有误)

    word文件打不开怎么办(word文档打不开显示内容有误)

  • 淘宝历史订单保留多久(淘宝历史订单保存在哪里)

    淘宝历史订单保留多久(淘宝历史订单保存在哪里)

  • 恢复出厂设置找回相册(恢复出厂设置找回电话号码)

    恢复出厂设置找回相册(恢复出厂设置找回电话号码)

  • 手机号加入黑名单怎么撤销(手机号加入黑名单会怎么样)

    手机号加入黑名单怎么撤销(手机号加入黑名单会怎么样)

  • 羊肚菌的功效(图文)(羊肚菌的功效和价格)

    羊肚菌的功效(图文)(羊肚菌的功效和价格)

  • 【vue2】近期bug收集与整理02(vue-bus)

    【vue2】近期bug收集与整理02(vue-bus)

  • element级联选择器选择获得完整数组(element级联选择器动态获取数据)

    element级联选择器选择获得完整数组(element级联选择器动态获取数据)

  • Go设计模式学习准备——下载bilibili合集视频(设计模式golang)

    Go设计模式学习准备——下载bilibili合集视频(设计模式golang)

  • 北京增值税发票网上申领流程
  • 借款合同印花税最新政策2023年
  • 对方开给我的专票遗失了,让我上传发票
  • 个人写的收据要留身份证复印件吗
  • 公司退款给客户怎么写
  • 租借车辆发生事故后的保险理赔问题
  • 无形资产加计扣除最新政策
  • 林木育种的意义和作用
  • 新会计准则要求
  • 退休返钱怎么算的
  • 软件的维修性要求
  • 小规模企业差额征收税率
  • 售楼部购买空调计入哪个科目
  • 先抵押 后租赁
  • 先收钱后开票怎么做分录
  • 工资中的话费补助是什么
  • 核定征收个体户怎么报税
  • 工商年检填写数据填错了会罚款吗
  • ios路由设计
  • win11打开图片
  • 农产品成本法计算抵扣
  • 工资薪金与劳务报酬的区别有哪些
  • 在线网速测试需要付费吗
  • 上年费用未计提
  • 什么是坏账,坏账的核算方法有哪些
  • 公会经费缴费单位应于每月
  • 运行安装程序时发生错误
  • win10的电源设置
  • php foreach二维数组
  • 哪些货物可以享受减免税政策
  • 所得税返还计入什么科目
  • php中文出现乱码
  • php 通信
  • pycharm vue
  • Pytorch深度学习实战3-6:详解网络骨架模块nn.Module(附实例)
  • vue3的ref,reactive的使用和原理解析
  • python机器人编程控制
  • laravel 分页 api
  • 基本数据结构包括哪些
  • 公司食堂吃饭没钱怎么办
  • 数据库平移
  • mysql 索引 key
  • 发票金额跟实际转账金额不一样该怎么办
  • 企业增值税的征收方式
  • 固定资产处置净收入转入什么账户核算
  • 个税汇算清缴是退税吗
  • 一个人可以做多个担保人吗
  • 收到员工社保
  • 汽车折旧年限与折旧率
  • 建筑业发票可以抵扣制造业进项
  • 购入固定资产中的增值税
  • 成本核算的内容有哪几个方面
  • 报关单位分为几种类型?其业务范围有何不同?
  • 设备信息windows6.1
  • vista启用aero
  • centos怎么连接远程服务器
  • scanfile.exe
  • ias.exe是什么程序
  • rtlrack.exe - rtlrack是什么进程 有什么用
  • hyper-v win98
  • win7禁用administrator
  • 获取android id
  • cocos code ide 1.0.0 RC0 使用教程
  • cocos2dx怎么用
  • angular实战
  • android get
  • python入门笔记
  • android设计模式总结
  • 使用vue开发手机app
  • jquery 延迟对象
  • android framework 框架层功能梳理
  • js鼠标拖动窗口的做法
  • js咋用
  • 甘肃增值税发票查验平台官网
  • 国税电子版
  • 上海公积金快速提取
  • 个人无偿捐赠增值税
  • 乌鲁木齐市公立幼儿园有哪些
  • 大连市国家税务网
  • 个体工商户税收起征点是多少
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设