位置: IT常识 - 正文

【图神经网络实战】深入浅出地学习图神经网络GNN(上)(图神经网络gat)

编辑:rootadmin
【图神经网络实战】深入浅出地学习图神经网络GNN(上) 文章目录一、图神经网络应用领域1.1 芯片设计1.2 场景分析与问题推理1.3 推荐系统1.4 欺诈检测与风控相关1.5 知识图谱1.6 道路交通的流量预测1.7 自动驾驶(无人机等场景)1.8 化学,医疗等场景1.9 物理模型相关二、图神经网络基本知识2.1 图基本模块定义2.2 图神经网络要做的事情2.3 邻接矩阵的定义2.3.1 图数据的邻接矩阵2.3.2 文本数据的邻接矩阵2.4 GNN中的常见任务2.4.1 Graph级别任务2.4.2 Node与Edge级别任务2.5 消息传递计算方法2.5.1 优化邻接矩阵2.5.2 点的特征重构2.6 多层GNN的作用GNN输出特征的用处三、GCN详解3.1 GCN基本模型概述3.1.1 卷积 vs 图卷积3.1.2 图中常见任务3.1.3 如果获取特征3.1.4 半监督学习3.2 图卷积的基本计算方法3.2.1 GCN基本思想3.2.2 GCN层数3.2.3 图中基本组成3.2.4 特征计算方法3.3 邻接矩阵的变换3.4 GCN变换原理解读3.5 GCN传播公式四、PyTorch Geometric 库的基本使用4.1 PyTorch Geometric 的安装4.2 数据集与邻接矩阵格式4.2.1 数据集介绍4.2.2 数据探索4.2.3 使用networkx进行可视化展示4.2.4 GCN模型搭建4.2.5 使用搭建好的GCN模型五、文献引用数据集分类案例实战(基于点的任务)5.1 数据集介绍5.2 数据探索5.3 试试传统MLP的效果5.4 再看看GCN的效果六、构建自己的图数据集七、基于图神经网络的电商购买预测实例7.1 数据集介绍

推荐整理分享【图神经网络实战】深入浅出地学习图神经网络GNN(上)(图神经网络gat),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:图神经网络工作原理,图 神经网络,图神经网络 知乎,图神经网络2021,图神经网络模型,图神经网络教程,图神经网络教程,图神经网络2021,内容如对您有帮助,希望把文章链接给更多的朋友!

本文为学习产物,学习链接(如有侵权,请告知删除): 人工智能【图神经网络实战】教程,让你一天就学会深入浅出图神经网络GNN,从入门到精通!

一、图神经网络应用领域1.1 芯片设计

芯片的设计比较耗费人力和物力,如果可以通过AI算法自动设计芯片,则可以大大提高芯片制造的效率,降低芯片制造的成本

1.2 场景分析与问题推理

例如剧本杀中的推理,警匪片中嫌疑人的图推理等

1.3 推荐系统

例如,刷抖音,经常看英雄联盟的游戏视频,那么说明你对游戏比较感兴趣,系统会根据网络图结构推荐更多和英雄联盟相关的内容给你

1.4 欺诈检测与风控相关

贷款软件,读取用户的通讯录信息和app使用情况,从而测评用户的还款能力,然后决定用户的借款额度

1.5 知识图谱

智能客服

1.6 道路交通的流量预测

预测道路上每条边的流量

1.7 自动驾驶(无人机等场景)

1.8 化学,医疗等场景

利用AI对化学结构进行分析,预测

1.9 物理模型相关

根据分子结构进行相关分析

二、图神经网络基本知识2.1 图基本模块定义

V:点,每个点都有自己的特征向量(特征举例:邻居点数量、一阶二阶相似度) E:边,每个边都有自己的特征向量(特征举例:边的权重值、边的定义) U:整个图,每个图都有自己的特征向量(特征举例:节点数量、图直径)

2.2 图神经网络要做的事情为每个节点整合特征向量,根据其对节点做分类或者回归为每条边整合特征向量,根据其对边做分类或者回归为每张图整合特征向量,根据其对图做分类或者回归 2.3 邻接矩阵的定义2.3.1 图数据的邻接矩阵

2.3.2 文本数据的邻接矩阵

2.4 GNN中的常见任务

传统神经网络(CNN、RNN、DNN)要求输入格式是固定的(如24×24、128×128等)。

但在实际场景中(例如道路交通),不同城市的道路数量和节点数量都不同,即输入数据格式不固定。对此,传统神经网络不能很好地解决,但是GNN可以用来解决此类问题。 对于输入数据格式不固定的情况,GNN的常见任务有以下几种:

2.4.1 Graph级别任务

基于整个图,做分类和回归。 例如,给定一个分子结构图,判断它里面存在几个环 或者 判断该分子结构属于哪一类

2.4.2 Node与Edge级别任务

预测这个点是教练还是学员,即预测点 预测两个点之间的关系(是打架关系还是观看关系),即预测边

2.5 消息传递计算方法2.5.1 优化邻接矩阵

之前学过,邻接矩阵的大小为N*N,当节点很多的时候,邻接矩阵的大小也会特别大

为了解决这个问题,我们一般采取只保存source数组和target数组的方式

source数组即起点(起源点)数组,target数组即终点(目标点)数组

这两个数组的维度是一样的

对应位置的source和target值就可以代表一条可连接的有向边,对于没有连接关系的边则不需要保存其信息,这样就可以大大减少数据规模

2.5.2 点的特征重构

汇总 = 自身的信息 + 所有邻居点的信息

所有邻居点信息的表达有几种:

求解Sum求平均Mean求最大Max求最小Min 2.6 多层GNN的作用

层数越多,GNN的“感受野”越大,每个点考虑其他点的信息越多,考虑越全面

GNN输出特征的用处

三、GCN详解3.1 GCN基本模型概述3.1.1 卷积 vs 图卷积

卷积:卷积核平移计算

图卷积:自身信息+所有邻居信息

3.1.2 图中常见任务

3.1.3 如果获取特征

3.1.4 半监督学习【图神经网络实战】深入浅出地学习图神经网络GNN(上)(图神经网络gat)

GCN属于半监督学习(不需要每个节点都有标签都可以进行训练)

计算Loss时,只需要考虑有标签的节点即可。

为了减少有标签节点的Loss,其周围的点也会做相应的调整,这也是图结构的特点,因此GNN和GCN中,不需要所有节点都有标签也可以进行训练(当然至少需要一个节点有标签)

3.2 图卷积的基本计算方法3.2.1 GCN基本思想对每个节点计算特征然后合成每个节点的特征将合成的特征传入全连接网络进行分类 3.2.2 GCN层数

图卷积也可以做多层,但是一般不做太深层,一般只做2-3层 (类似于一种说法,你只需要认识6个人就可以认识全世界) 实验表明:GCN中,深层的网络结构往往不会带来更好的效果。 直观解释:我表哥认识的朋友的朋友的朋友的朋友认识市长,不代表我和市长关系就很好。

层数越多,特征表达就越发散

一般2-5层即可

3.2.3 图中基本组成

3.2.4 特征计算方法

3.3 邻接矩阵的变换

单位矩阵相当于给每个节点加了一条自连接的边

但是现在存在一个问题:一个节点的度越大,其做矩阵乘法后的值就越大(累加次数变多了),这种情况是不好的(相当于一个人认识的人越多,其的特征值就越大,这样不好)

为了解决这个问题,我们需要对度矩阵求倒数,相当于平均的感觉,对度数大的节点加以限制 上面的左乘相当于对行做了归一化操作,那么列也需要做归一化操作 但是又有问题了,行和列都做了归一化,那不是会存在2次归一化的情况吗(行列重叠处)

所以我们需要在度矩阵倒数那加一个0.5次方来抵消这个2次归一化的影响

3.4 GCN变换原理解读

如下图所示,假设绿色框中的人是个富人,红色框中的人是个穷人,他们只是小时候认识,穷人只认识富人,而富人认识很多人。

如果只对行做归一化,由于穷人只认识富人,所以其度为1,则其在进行特征重构的时候很大一部分信息会来自于富人,这样的模型大概率会认为穷人和富人是同一种人。显然,这是不合理的

所以,我们需要同时对行和列都进行归一化,这样不仅只考虑富人对穷人的关系,还考虑了穷人对富人的关系。

简单来说,对行做归一化考虑到了,富人对穷人来说很重要;对列作归一化,考虑到了穷人对富人可能没那么重要(因为富人的度很大,穷人的度很小,富人很可能不记得穷人了),这样相对更加合理。

3.5 GCN传播公式

softmax是作多分类常用的激活函数

四、PyTorch Geometric 库的基本使用4.1 PyTorch Geometric 的安装

注意: 千万不要直接pip install 去安装这个库!!!

进入这个GitHub网址: https://github.com/pyg-team/pytorch_geometric

进入页面后往下滑,找到如下图所示的字样,点击here

选择你电脑中已经安装的torch版本(一定要和你已经安装的torch版本一致)

怎么查看torch版本?

在Pycharm中,点击底部栏的Terminal,输入pip show torch,即可查看torch版本

选择完正确的torch版本后,会进入下面的界面,一共有4个不同的.whl文件,每一种选一个符合你的版本下载即可

例如:torch_cluster-1.5.9-cp36-cp36m-win_amd64.whl 指的是python为3.6的windows版本

我的电脑是windows的,python版本为3.8.12,所以我下载的四个包如下图所示: 下载好之后,直接pip install 你的.whl文件地址

下面是我安装时候的命令(仅供参考):

pip install C:\Users\WSKH\Desktop\torch_cluster-1.5.9-cp38-cp38-win_amd64.whlpip install ‪C:\Users\WSKH\Desktop\torch_scatter-2.0.6-cp38-cp38-win_amd64.whlpip install C:\Users\WSKH\Desktop\torch_sparse-0.6.9-cp38-cp38-win_amd64.whlpip install ‪C:\Users\WSKH\Desktop\torch_spline_conv-1.2.1-cp38-cp38-win_amd64.whl

最后,一定要等上面四步完成之后,再执行下面的操作

pip install torch-geometric4.2 数据集与邻接矩阵格式4.2.1 数据集介绍

Hello World 级别的数据集,34个节点

4.2.2 数据探索from torch_geometric.datasets import KarateClubdataset = KarateClub()print(f'Dataset:{dataset}:')print('=' * 30)print(f'Number of graphs:{len(dataset)}')print(f'Number of features:{dataset.num_features}')print(f'Number of classes:{dataset.num_classes}')print('=' * 30)data = dataset[0]# train_mask = [True,False,...] :代表第1个点是有标签的,第2个点是没标签的,方便后面LOSS的计算print(data) # Data(x=[节点数, 特征数], edge_index=[2, 边的条数], y=[节点数], train_mask=[节点数])

输出:

Dataset:KarateClub():==============================Number of graphs:1Number of features:34Number of classes:4==============================Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])4.2.3 使用networkx进行可视化展示import osfrom torch_geometric.datasets import KarateClubfrom torch_geometric.utils import to_networkximport networkx as nximport matplotlib.pyplot as plt# 画图函数def visualize_graph(G, color): plt.figure(figsize=(7, 7)) plt.xticks([]) plt.yticks([]) nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False, node_color=color, cmap="Set2") plt.show()# 画点函数def visualize_embedding(h, color, epoch=None, loss=None): plt.figure(figsize=(7, 7)) plt.xticks([]) plt.yticks([]) h = h.detach().cpu().numpy() plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2") if epoch is not None and loss is not None: plt.xlabel(f'Epoch:{epoch},Loss:{loss.item():.4f}', fontsize=16) plt.show()if __name__ == '__main__':# 不加这个可能会报错 os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' dataset = KarateClub() print(f'Dataset:{dataset}:') print('=' * 30) print(f'Number of graphs:{len(dataset)}') print(f'Number of features:{dataset.num_features}') print(f'Number of classes:{dataset.num_classes}') print('=' * 30) data = dataset[0] # train_mask = [True,False,...] :代表第1个点是有标签的,第2个点是没标签的,方便后面LOSS的计算 print(data) # Data(x=[节点数, 特征数], edge_index=[2, 边的条数], y=[节点数], train_mask=[节点数]) G = to_networkx(data, to_undirected=True) visualize_graph(G, color=data.y)

可视化结果:

4.2.4 GCN模型搭建import torchfrom torch.nn import Linearfrom torch_geometric.nn import GCNConvclass GCN(torch.nn.Module): def __init__(self, num_features, num_classes): super(GCN, self).__init__() torch.manual_seed(520) self.num_features = num_features self.num_classes = num_classes self.conv1 = GCNConv(self.num_features, 4) # 只定义子输入特证和输出特证即可 self.conv2 = GCNConv(4, 4) self.conv3 = GCNConv(4, 2) self.classifier = Linear(2, self.num_classes) def forward(self, x, edge_index): # 3层GCN h = self.convl(x, edge_index) # 给入特征与邻接矩阵(注意格式,上面那种) h = h.tanh() h = self.conv2(h.edge_index) h = h.tanh() h = self.conv3(h, edge_index) h = h.tanh() # 分类层 out = self.classifier(h) return out, h4.2.5 使用搭建好的GCN模型import osimport timefrom torch_geometric.datasets import KarateClubimport networkx as nximport matplotlib.pyplot as pltimport torchfrom torch.nn import Linearfrom torch_geometric.nn import GCNConv# 画图函数def visualize_graph(G, color): plt.figure(figsize=(7, 7)) plt.xticks([]) plt.yticks([]) nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False, node_color=color, cmap="Set2") plt.show()# 画点函数def visualize_embedding(h, color, epoch=None, loss=None): plt.figure(figsize=(7, 7)) plt.xticks([]) plt.yticks([]) h = h.detach().cpu().numpy() plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2") if epoch is not None and loss is not None: plt.xlabel(f'Epoch:{epoch},Loss:{loss.item():.4f}', fontsize=16) plt.show()class GCN(torch.nn.Module): def __init__(self, num_features, num_classes): super(GCN, self).__init__() torch.manual_seed(520) self.num_features = num_features self.num_classes = num_classes self.conv1 = GCNConv(self.num_features, 4) # 只定义子输入特证和输出特证即可 self.conv2 = GCNConv(4, 4) self.conv3 = GCNConv(4, 2) self.classifier = Linear(2, self.num_classes) def forward(self, x, edge_index): # 3层GCN h = self.conv1(x, edge_index) # 给入特征与邻接矩阵(注意格式,上面那种) h = h.tanh() h = self.conv2(h, edge_index) h = h.tanh() h = self.conv3(h, edge_index) h = h.tanh() # 分类层 out = self.classifier(h) return out, h# 训练函数def train(data): optimizer.zero_grad() out, h = model(data.x, data.edge_index) loss = criterion(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() return loss, hif __name__ == '__main__': # 不加这个可能会报错 os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' # 数据集准备 dataset = KarateClub() data = dataset[0] #
本文链接地址:https://www.jiuchutong.com/zhishi/288124.html 转载请保留说明!

上一篇:Web 页面之间传递参数的几种方法(html页面间传数据)

下一篇:薄雾笼罩的河流中的丹顶鹤,日本北海道 (© Paul & Paveena Mckenzie/Getty Images)(薄雾笼罩着整个森林)

  • 闲鱼怎么取消芝麻信用授权(闲鱼怎么取消芝麻应用授权)

    闲鱼怎么取消芝麻信用授权(闲鱼怎么取消芝麻应用授权)

  • word视图模式有几种(word视图模式有哪些)

    word视图模式有几种(word视图模式有哪些)

  • rx470d相当于n卡什么(rx470相当于n卡什么)

    rx470d相当于n卡什么(rx470相当于n卡什么)

  • 微信语音截图并播放怎么操作(微信语音截图并播放怎么操作苹果手机)

    微信语音截图并播放怎么操作(微信语音截图并播放怎么操作苹果手机)

  • 爱奇艺学生会员限制(爱奇艺学生会员可以登录几个设备)

    爱奇艺学生会员限制(爱奇艺学生会员可以登录几个设备)

  • 华为mate30与华为p40的区别(华为mate30与华为p40)

    华为mate30与华为p40的区别(华为mate30与华为p40)

  • 三星s20+和三星s10+对比(三星s20+和三星s10+参数对比)

    三星s20+和三星s10+对比(三星s20+和三星s10+参数对比)

  • 华为手机能下载两个微信吗(华为手机能下载ins吗)

    华为手机能下载两个微信吗(华为手机能下载ins吗)

  • 学习通怎么下载(智学网怎么下载)

    学习通怎么下载(智学网怎么下载)

  • 视频控制器vga兼容是什么(视频控制器vga兼容感叹号)

    视频控制器vga兼容是什么(视频控制器vga兼容感叹号)

  • 电脑未检测到摄像头怎么办(电脑未检测到摄像头设备是什么意思)

    电脑未检测到摄像头怎么办(电脑未检测到摄像头设备是什么意思)

  • 华为mate30pro触屏不灵敏(华为mate30pro触屏不灵敏怎样解决)

    华为mate30pro触屏不灵敏(华为mate30pro触屏不灵敏怎样解决)

  • iPhone7用pd充电器会坏吗(iphone7p可以用pd充电器吗)

    iPhone7用pd充电器会坏吗(iphone7p可以用pd充电器吗)

  • 指纹验证失败什么原因(指纹验证失败什么情况)

    指纹验证失败什么原因(指纹验证失败什么情况)

  • d盘格式化了能恢复吗(格式化d盘后能恢复数据吗)

    d盘格式化了能恢复吗(格式化d盘后能恢复数据吗)

  • 华帝燃气热水器e2什么意思(华帝燃气热水器e2故障怎么解决)

    华帝燃气热水器e2什么意思(华帝燃气热水器e2故障怎么解决)

  • 华为荣耀20nfc怎么复制门禁卡(华为荣耀20nfc怎么读取公交卡)

    华为荣耀20nfc怎么复制门禁卡(华为荣耀20nfc怎么读取公交卡)

  • word竖式除号怎么打(word怎么打除法竖式的符号)

    word竖式除号怎么打(word怎么打除法竖式的符号)

  • 锁单啥意思(锁单操作技巧视频)

    锁单啥意思(锁单操作技巧视频)

  • 苹果6s屏幕尺寸大小(苹果6s屏幕尺寸多少)

    苹果6s屏幕尺寸大小(苹果6s屏幕尺寸多少)

  • ios13黑夜模式省电吗(ios13.6夜间模式)

    ios13黑夜模式省电吗(ios13.6夜间模式)

  • 荣耀20一键清理在哪里(荣耀20一键清理可以自定义吗)

    荣耀20一键清理在哪里(荣耀20一键清理可以自定义吗)

  • 荣耀20怎么结束后台(荣耀手机怎么结束运行程序)

    荣耀20怎么结束后台(荣耀手机怎么结束运行程序)

  • 怎么通过身份证号查电话号码(怎么通过身份证号码查个人信息)

    怎么通过身份证号查电话号码(怎么通过身份证号码查个人信息)

  • 腾讯文档转word(腾讯文档转Word文档出错)

    腾讯文档转word(腾讯文档转Word文档出错)

  • 系统软件的核心是(系统软件部分的核心)

    系统软件的核心是(系统软件部分的核心)

  • series3和4的区别(series 3)

    series3和4的区别(series 3)

  • oppok3有红外线吗(oppok1手机有红外线功能吗)

    oppok3有红外线吗(oppok1手机有红外线功能吗)

  • 三星s8屏幕有残影(三星s8屏幕有残影怎么办)

    三星s8屏幕有残影(三星s8屏幕有残影怎么办)

  • 小迪安全day08信息收集-架构,搭建,WAF(小迪安全2021)

    小迪安全day08信息收集-架构,搭建,WAF(小迪安全2021)

  • 税前扣除凭证按照用途分为哪些
  • 房地产企业利息资本化的条件
  • 不动产租赁税率9%
  • 增值税退税如何做账
  • 福利费进项税额转出会计分录账务处理
  • 税务局开普票怎么开
  • 折扣折让的销售方式有哪些
  • 上月应交税金
  • 增值税专用发票的税率是多少啊
  • 出口企业是外贸企业吗
  • 会计去报税流程
  • 工商年报填错了能改吗
  • 没有认证的进项发票怎么入账
  • 增值税先征后退属于政府补助吗
  • 电子钥匙到期怎么办
  • 公司购买自用房产税如何征收
  • 暂估金额与发票金额的区别
  • 苹果系统如何访问相册
  • mac小技巧
  • 苹果助手hi
  • 调整应收账款如何做账
  • php 时间差
  • 与资产相关的政府补助,如果相关资产在使用寿命结束时
  • 预收账款调增应纳税所得额
  • 补缴增值税和滞纳税区别
  • 个人出售二手房要交增值税吗
  • 马德拉海岛
  • sass转化为css
  • php连接数据库实现登录注册
  • vue组件引入外部js
  • 发票密码区如何调整
  • vue面试题2020
  • opencv图像识别特定形状
  • 未分配利润是负数是亏损吗
  • 总公司与分公司合作协议范本
  • 生产成本有什么
  • 资金结存属于什么会计科目
  • 个人提供劳务怎么去税务局开发票
  • 支付版权使用费怎么记账
  • 运费和什么有关
  • 年终奖的个税税率表
  • 销售产品的运输费会计分录
  • sql server2012新建一个数据表
  • 普通发票一般几个点
  • 机票的退票费可以开具发票吗
  • 增加固定资产原值
  • 勾选认证能够勾选当月
  • 公司购买的车辆折旧年限
  • 进口关税,增值税是进口设备重置成本中的从属费用
  • 物业公司收物业费不开发票违法吗
  • 政府补贴资金如何记账
  • 公账发工资如何记账
  • 金融工具减值准则
  • 现金日记账1月份本年累计吗
  • 汇算清缴产生的企业所得税如何做账
  • mysql 判断
  • solaris 修改用户 主目录
  • 单网卡计算机有几个网络接口
  • ubuntu怎么安装程序
  • linux编译步骤
  • 如何显示文件后缀win10
  • virtualbox虚拟机菜单找不到了
  • linux删除u盘记录
  • Win10中SmartScreen无法设置需要系统管理员身份该怎么办?
  • ControlSet001、ControlSet002以及CurrentControlSet之间有什么区别
  • win7使用率
  • 电脑市场调查报告
  • linux使用mv命令,结果文件不见了
  • Ver、Vol、Ctty命令的使用教程
  • js显示时间并且之后秒数实时更新
  • 如何实现左侧固定,右侧自适应的布局
  • jquery fadein 源码
  • linux监控cpu使用率脚本
  • angularjs1.5
  • jQuery EasyUI之DataGrid使用实例详解
  • js中的substring
  • 日历 caldav
  • 江西省国家税务局发票查询
  • 福建原盐和自然盐有什么区别
  • 在哪里查看法律
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设