位置: IT常识 - 正文

Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding)

编辑:rootadmin
Pytorch实现EdgeCNN(基于PyTorch实现) 文章目录前言一、导入相关库二、加载Cora数据集三、定义EdgeCNN网络3.1 定义EdgeConv层3.1.1 特征拼接3.1.2 max聚合3.1.3 特征映射3.1.4 EdgeConv层3.2 定义EdgeCNN网络四、定义模型五、模型训练六、模型验证七、结果完整代码前言

推荐整理分享Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch encoder decoder,pytorch embedding lookup,pytorch embedding lookup,pyTorch实现多分类预测,pytorch demo,pytorch encoder decoder,pyTorch实现多分类预测,pytorch generator,内容如对您有帮助,希望把文章链接给更多的朋友!

大家好,我是阿光。

本专栏整理了《图神经网络代码实战》,内包含了不同图神经网络的相关代码实现(PyG以及自实现),理论与实践相结合,如GCN、GAT、GraphSAGE等经典图网络,每一个代码实例都附带有完整的代码。

正在更新中~ ✨

🚨 我的项目环境:

平台:Windows10语言环境:python3.7编译器:PyCharmPyTorch版本:1.11.0PyG版本:2.1.0

💥 项目专栏:【图神经网络代码实战目录】

本文我们将使用PyTorch来简易实现一个EdgeCNN,不使用PyG库,让新手可以理解如何PyTorch来搭建一个简易的图网络实例demo。

一、导入相关库

本项目是采用自己实现的EdgeCNN,并没有使用 PyG 库,原因是为了帮助新手朋友们能够对EdgeConv的原理有个更深刻的理解,如果熟悉之后可以尝试使用PyG库直接调用 EdgeConv 这个图层即可。

import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch_geometric.utils import scatterfrom torch_geometric.datasets import Planetoid二、加载Cora数据集

本文使用的数据集是比较经典的Cora数据集,它是一个根据科学论文之间相互引用关系而构建的Graph数据集合,论文分为7类,共2708篇。

Genetic_AlgorithmsNeural_NetworksProbabilistic_MethodsReinforcement_LearningRule_LearningTheory

这个数据集是一个用于图节点分类的任务,数据集中只有一张图,这张图中含有2708个节点,10556条边,每个节点的特征维度为1433。

# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')三、定义EdgeCNN网络3.1 定义EdgeConv层

这里我们就不重点介绍EdgeCNN网络了,相信大家能够掌握基本原理,本文我们使用的是PyTorch定义网络层。

对于EdgeConv的常用参数:

nn:进行节点特征转换使用的 MLP网络,需要自己定义传入aggr:聚合邻居节点特征时采用的方式,默认为 max

我们在实现时也是考虑这几个常见参数

对于EdgeConv的传播公式为: xi′=∑j∈N(i)hθ(xi∣∣xj−xi)x_i'=\sum_{j\in N(i)}h_{\theta}(x_i||x_j-x_i)xi′​=j∈N(i)∑​hθ​(xi​∣∣xj​−xi​)

上式子中的 xix_ixi​ 代表中心节点特征信息, xjx_jxj​ 代表邻居节点的特征信息,对于 hθh_{\theta}hθ​ 代表每个 EdgeConv 层的可学习参数,也就是对应传入的MLP层中的可学习参数。

Pytorch实现EdgeCNN(基于PyTorch实现)(pytorch中embedding)

所以我们的任务无非就是获取这几个变量,然后进行传播计算即可

3.1.1 特征拼接

该环节实现的公式为:xi∣∣xj−xix_i||x_j-x_ixi​∣∣xj​−xi​,对于这个公式来说,我们要获得两个变量,一个是中心节点 xix_ixi​(target)的特征信息,一个是邻居节点 xjx_jxj​(source)的特征信息。

对于这两个变量的获取很容易,利用 edge_index 就可以提取出来,edge_index 中保存的是每一条边的一对起始节点与终止节点,对于起始节点可以认为就是 i,对于终止节点就可以认为是 j,然后我们就会获得两个向量,分别为 row 和 col ,这两个向量就是起始顶点和终止顶点的集合。

然后我们在根据索引进行提取特征,利用 x_i = x[row] 和 x_j = x[col] 就可以将中心节点和终止节点对应的特征获取,维度为【E,feature_size】。

然后就可以按照公式实现做差然后与中心节点的特征进行拼接,获得拼接后的特征维度为原来的2倍。

row, col = edge_index # 获取target、source节点索引 [E]x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size]x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size]x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size]

对于这里 x_i 和 x_j 以及起始节点的索引初学可能混淆,所以多多打印中间结果一步一步调试进行理解。

3.1.2 max聚合

对于 EdgeConv 的默认聚合方式为 max,其实还可以使用 mean 、sum 等排列不变函数进行聚合。

对于聚合操作就是公式中求和符号那里,只不过框架给的公式是 sum ,对于聚合我们希望做的是将中心节点的邻居特征按照指定的聚合方式进行聚合。

我们可以利用 PyG 工具库中提供的 scatter 函数进行操作,该函数可以指定聚合方式以及聚合维度等参数,使用方法就是需要传入需要聚合的 Tensor ,此外还需要传入一个 index ,指明哪些向量为同一个邻居的节点,举个例子,我们传入的 index=[0,0,0,1,1] ,这就代表第一个、第二个、第三个为同一节点的邻居,所以就会将待聚合的 Tensor 的第一个向量、第二个向量、第三个向量按照指定聚合方式进行聚合。

这里说的有点抽象,自己尝试一个简单示例就明白了。

out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size]3.1.3 特征映射

在公式中有个 hθh_{\theta}hθ​,这个就代表 MLP 做特征映射做的,对于官方给的 EdgeConv 需要我们手动传入 MLP 模型,所以本项目自实现也是按照这种方式,MLP 的操作在 EdgeConv 中并没有实现,而是利用传入的模型进行操作。

这里注意一点就是定义的 MLP 模型的输入维度应该为原始维度的2倍,因为我们在这之前进行了特征拼接操作,所以特征维度进行了加倍。

out = self.mlp(out) # 特征映射 [num_nodes, out_channels]3.1.4 EdgeConv层

接下来就可以定义EdgeConv层了,该层实现了1个函数,为 forward()

forward():这个函数定义模型的传播过程,也就是上面公式的 xi′=∑j∈N(i)hθ(xi∣∣xj−xi)x_i'=\sum_{j\in N(i)}h_{\theta}(x_i||x_j-x_i)xi′​=∑j∈N(i)​hθ​(xi​∣∣xj​−xi​)# 2.定义EdgeConv层class EdgeConv(nn.Module): def __init__(self, nn, aggr='max'): super(EdgeConv, self).__init__() self.mlp = nn # MLP网络 def forward(self, x, edge_index): row, col = edge_index # 获取target、source节点索引 [E] x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size] x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size] x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size] out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size] out = self.mlp(out) # 特征映射 [num_nodes, out_channels] return out

对于我们实现这个网络的实现效率上来讲比PyG框架内置的 EdgeConv 层稍差一点,因为我们是按照公式来一步一步利用矩阵计算得到,没有对矩阵计算以及算法进行优化,不然初学者可能看不太懂,不利于理解EdgeConv公式的传播过程,有能力的小伙伴可以看下官方源码学习一下,框架内是按照消息传递方式实现的。

3.2 定义EdgeCNN网络

上面我们已经实现好了 EdgeConv 的网络层,之后就可以调用这个层来搭建 EdgeCNN 网络。

# 3.定义EdgeConv网络class EdgeCNN(nn.Module): def __init__(self, num_node_features, num_classes): super(EdgeCNN, self).__init__() self.conv1 = EdgeConv(nn=nn.Linear(2 * num_node_features, 16), aggr='max') self.conv2 = EdgeConv(nn=nn.Linear(2 * 16, num_classes), aggr='max') def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)

上面网络我们定义了两个EdgeConv层,第一层的参数的输入维度就是初始每个节点的特征维度 * 2,输出维度是16。

第二个层的输入维度为16 * 2,输出维度为分类个数,因为我们需要对每个节点进行分类,最终加上softmax操作。

这里说明一下为什么要将输入乘以2,原因是在使用MLP进行特征转换之前,会将中心节点的特征与中心节点和邻居节点的差向量做拼接,所以得到的输出维度为节点的特征维度 * 2。

四、定义模型

下面就是定义了一些模型需要的参数,像学习率、迭代次数这些超参数,然后是模型的定义以及优化器及损失函数的定义,和pytorch定义网络是一样的。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 10 # 学习轮数lr = 0.003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 3.定义模型model = EdgeCNN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数五、模型训练

模型训练部分也是和pytorch定义网络一样,因为都是需要经过前向传播、反向传播这些过程,对于损失、精度这些指标可以自己添加。

# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')六、模型验证

下面就是模型验证阶段,在训练时我们是只使用了训练集,测试的时候我们使用的是测试集,注意这和传统网络测试不太一样,在图像分类一些经典任务中,我们是把数据集分成了两份,分别是训练集、测试集,但是在Cora这个数据集中并没有这样,它区分训练集还是测试集使用的是掩码机制,就是定义了一个和节点长度相同纬度的数组,该数组的每个位置为True或者False,标记着是否使用该节点的数据进行训练。

# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))七、结果【EPOCH: 】1训练损失为:1.9629 训练精度为:0.1214【EPOCH: 】21训练损失为:1.6709 训练精度为:0.5714【EPOCH: 】41训练损失为:1.3965 训练精度为:0.7571【EPOCH: 】61训练损失为:1.1095 训练精度为:0.8643【EPOCH: 】81训练损失为:0.9088 训练精度为:0.9286【EPOCH: 】101训练损失为:0.7454 训练精度为:0.9643【EPOCH: 】121训练损失为:0.5841 训练精度为:0.9643【EPOCH: 】141训练损失为:0.4985 训练精度为:0.9714【EPOCH: 】161训练损失为:0.3954 训练精度为:0.9714【EPOCH: 】181训练损失为:0.3339 训练精度为:0.9857【Finished Training!】>>>Train Accuracy: 1.0000 Train Loss: 0.3133>>>Test Accuracy: 0.4230 Test Loss: 1.6562训练集测试集Accuracy1.00000.4230Loss0.31331.6562完整代码import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch_geometric.utils import scatterfrom torch_geometric.datasets import Planetoid# 1.加载Cora数据集dataset = Planetoid(root='./data/Cora', name='Cora')# 2.定义EdgeConv层class EdgeConv(nn.Module): def __init__(self, nn, aggr='max'): super(EdgeConv, self).__init__() self.mlp = nn # MLP网络 def forward(self, x, edge_index): row, col = edge_index # 获取target、source节点索引 [E] x_i = x[row] # 获取target节点信息,中心节点 [E, feature_size] x_j = x[col] # 获取source节点信息,邻居节点 [E, feature_size] x_cat = torch.cat([x_i, x_j - x_i], dim=1) # 拼接特征 [E, 2 * feature_size] out = scatter(src=x_cat, index=row, dim=0, reduce='max') # max聚合操作 [num_nodes, feature_size] out = self.mlp(out) # 特征映射 [num_nodes, out_channels] return out# 3.定义EdgeConv网络class EdgeCNN(nn.Module): def __init__(self, num_node_features, num_classes): super(EdgeCNN, self).__init__() self.conv1 = EdgeConv(nn=nn.Linear(2 * num_node_features, 16), aggr='max') self.conv2 = EdgeConv(nn=nn.Linear(2 * 16, num_classes), aggr='max') def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备epochs = 200 # 学习轮数lr = 0.0003 # 学习率num_node_features = dataset.num_node_features # 每个节点的特征数num_classes = dataset.num_classes # 每个节点的类别数data = dataset[0].to(device) # Cora的一张图# 4.定义模型model = EdgeCNN(num_node_features, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器loss_function = nn.NLLLoss() # 损失函数# 训练模式model.train()for epoch in range(epochs): optimizer.zero_grad() pred = model(data) loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失 correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目 acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度 loss.backward() optimizer.step() if epoch % 20 == 0: print("【EPOCH: 】%s" % str(epoch + 1)) print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证model.eval()pred = model(data)# 训练集(使用了掩码)correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()acc_train = correct_count_train / data.train_mask.sum().item()loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()acc_test = correct_count_test / data.test_mask.sum().item()loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))
本文链接地址:https://www.jiuchutong.com/zhishi/300575.html 转载请保留说明!

上一篇:用css画一个csdn程序猿(用css画一个扇形)

下一篇:网易二面:CPU狂飙900%,该怎么处理?(网易游戏二面)

  • 发朋友圈地址怎么设置自定义(发朋友圈地址怎么设置别的城市)

    发朋友圈地址怎么设置自定义(发朋友圈地址怎么设置别的城市)

  • qq如何更改语言(qq如何更改语言设置)

    qq如何更改语言(qq如何更改语言设置)

  • 苹果11分屏多窗口的方法(苹果11屏幕分身)

    苹果11分屏多窗口的方法(苹果11屏幕分身)

  • 耳机左耳比右耳声音大(耳机左耳比右耳耗电快)

    耳机左耳比右耳声音大(耳机左耳比右耳耗电快)

  • 为什么卖家怕淘宝介入(为什么卖家怕差评)

    为什么卖家怕淘宝介入(为什么卖家怕差评)

  • iphone11摔了一下会变卡吗(iphone11摔了一下屏幕划不动但是有显示)

    iphone11摔了一下会变卡吗(iphone11摔了一下屏幕划不动但是有显示)

  • 电路板上ant是什么意思(电路板上的a)

    电路板上ant是什么意思(电路板上的a)

  • 固态硬盘256g科学分区(固态硬盘256g的一般多少钱)

    固态硬盘256g科学分区(固态硬盘256g的一般多少钱)

  • 调制调节器651处理办法(调制调节器651处理办法win10)

    调制调节器651处理办法(调制调节器651处理办法win10)

  • 外地的手机号码可以改成本地的吗(外地的手机号码可以在本地办理业务吗)

    外地的手机号码可以改成本地的吗(外地的手机号码可以在本地办理业务吗)

  • 号码加入黑名单对方听到的是什么(号码加入黑名单为什么还能打进来)

    号码加入黑名单对方听到的是什么(号码加入黑名单为什么还能打进来)

  • 下标怎么打快捷键(快捷键打下标)

    下标怎么打快捷键(快捷键打下标)

  • 字里面带横线怎么整(字带横线怎么打出来)

    字里面带横线怎么整(字带横线怎么打出来)

  • 苹果11是双卡双待双通吗(苹果11是双卡还是单卡)

    苹果11是双卡双待双通吗(苹果11是双卡还是单卡)

  • 小米有小爱同学oppo有什么(小米有小爱同学vivo有什么)

    小米有小爱同学oppo有什么(小米有小爱同学vivo有什么)

  • 路由器双千兆什么意思(路由器双千兆和千兆有什么区别)

    路由器双千兆什么意思(路由器双千兆和千兆有什么区别)

  • iqoo怎么手动打开液冷(iqoo手机怎么叫)

    iqoo怎么手动打开液冷(iqoo手机怎么叫)

  • 删除wps表格某些信息(wps表格如何删除部分内容)

    删除wps表格某些信息(wps表格如何删除部分内容)

  • 有赞买家怎么删除订单(怎么删除有赞订单购买成功的记录)

    有赞买家怎么删除订单(怎么删除有赞订单购买成功的记录)

  • 小米max2快充怎么设置(小米max2手机快充变慢充)

    小米max2快充怎么设置(小米max2手机快充变慢充)

  • 分屏键盘怎么变小(分屏键盘怎么变小vivo)

    分屏键盘怎么变小(分屏键盘怎么变小vivo)

  • 医院微信退款多久到账(医院微信退款多久到账户)

    医院微信退款多久到账(医院微信退款多久到账户)

  • 咸鱼交易评价怎么删除(咸鱼评价怎么处理)

    咸鱼交易评价怎么删除(咸鱼评价怎么处理)

  • 两个蓝牙音箱怎么互联(两个蓝牙音箱怎样连在一起播放)

    两个蓝牙音箱怎么互联(两个蓝牙音箱怎样连在一起播放)

  • 打电话闪光灯怎么设置(打电话闪光灯怎么关闭苹果)

    打电话闪光灯怎么设置(打电话闪光灯怎么关闭苹果)

  • DVDRegionFree.exe进程是安全程序吗 DVDRegionFree进程查询(dvd.rom)

    DVDRegionFree.exe进程是安全程序吗 DVDRegionFree进程查询(dvd.rom)

  • Spring Boot 3.0系列【19】核心特性篇之自定义Starter启动器(spring boot 2.3.0)

    Spring Boot 3.0系列【19】核心特性篇之自定义Starter启动器(spring boot 2.3.0)

  • 房地产企业税率为5销售水泥怎么算
  • 资产组可收回金额包含商誉吗
  • 公司基本户里的钱有利息吗
  • 开票和收到的款金额不一样怎么办?
  • 个税免税收入怎么进行更正申报
  • 生产型企业出口退税计算公式
  • 收到预付款的发票怎么写摘要
  • 出差回来报销差旅费,补付现金的会计分录
  • 员工自己领取社保卡需要带什么资料
  • 注资的设备出售怎么处理
  • 预收账款计入应纳税所得额
  • 建筑业统一发票可以抵扣吗
  • 营业账簿是什么意思
  • 无偿受让股权的股东对发起股东没有出资承担责任
  • 税务局每年都会查我公司虚开发票
  • 员工个人抬头的医院发票可以入账吗
  • 增值税专用发票可以开电子发票吗
  • 哪些票据可以冲销
  • 对外投资公司经营范围
  • 建筑公司一般纳税人增值税税率
  • 收据换发票的会计分录
  • 装饰公司发票怎么
  • 小微企业需要税务登记吗
  • 关于linux说法错误的是
  • 临时文件夹移动到c盘根目录下windows7
  • 出口退免税的基本政策包括
  • iphonexs如何强制关机重启
  • 会计中记账凭证的名词解释
  • 融资中的未确认利息
  • 公司收到款后怎么做账
  • 高温补贴需要缴纳社会保险费吗
  • 好奇地看着我
  • 如何导入并使用数据库
  • mac vue搭建本地环境
  • php 时间格式转换
  • 完美解决索尼电视arc无声音
  • yaf框架优缺点
  • flask 教程
  • nodejs hook
  • php file函数
  • 企业的所有分类
  • 一件代发退货如何处理
  • php decbin
  • 门诊收费票据能重新打印吗
  • 印花税申报完成如何缴纳
  • 固定资产的主要风险和关键控制点有哪些?
  • 债权投资利息收入调增还是调减
  • 小规模纳税人纳多少税
  • 当月利息发票未开可以先入账吗
  • 劳务报酬是自行缴纳吗
  • excel内账报表
  • 增值税期末留抵退税
  • 增值税税负率是多少
  • 应付账款的主要成本包括
  • 收购企业如何做账务处理
  • 固定资产累计折旧会计科目
  • 支付土地租金计入什么科目里面
  • 债务类科目和债权类科目
  • 企业坏账准备提取的方法和提取的比例由国家统一规定
  • 营业收入包括哪几项收入
  • 跨年做进项税额转出
  • mysql命令导入
  • sql server临时表创建语句
  • WIN10系统优化技巧
  • 丢失acui16.dll
  • win10的打开方式
  • 如何设置win10自动登录
  • win10缺少文件怎么办
  • 微软推送win11
  • win7系统盘u盘
  • linux防病毒措施
  • 投影变换的使用方法
  • mac vscode opengl
  • 基于web的学生成绩管理系统毕业论文
  • JavaScript中的case
  • js中的常用方法
  • 比较常见的电子商务模式
  • java 视频教程
  • 甘肃税务局电子发票怎么开
  • 网上交了购置税你要打印出来吗
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设