位置: IT常识 - 正文

异构图神经网络 RGCN、RGAT、HAN、GNN-FILM + PyG实现(异构图神经网络 电影推荐)

编辑:rootadmin
异构图神经网络 RGCN、RGAT、HAN、GNN-FILM + PyG实现 背景

推荐整理分享异构图神经网络 RGCN、RGAT、HAN、GNN-FILM + PyG实现(异构图神经网络 电影推荐),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:异构图神经网络 问题,异构图神经网络推荐系统,异构图神经网络化学领域,异构图神经网络的应用,异构图神经网络推荐系统,异构图神经网络推荐系统,异构图神经网络的服装搭配系统,异构图神经网络异常点检测,内容如对您有帮助,希望把文章链接给更多的朋友!

ICDM 2022 : 大规模电商图上的风险商品检测,要求在一张异构图上跑点分类,由于是异常检测,正负样本数据集在1比10,记录一下初赛过程。

数据

过程

赛事官方开源了PyG实现的baseline,拿过来直接用于预处理数据了,将图结构进行预处理后得到pt文件,使用pt文件做后续处理:

graph = torch.load(dataset) //dataset = "xxx.pt"graph[type].x = [num_nodes , 256] 点数*特征维度graph[type].y = [num_nodes] 标签=labelgraph[type].num_nodes = 数量 graph[type].maps = id 离散化映射:针对不同的type重新从0开始记录id# 异构图特殊存边方式,需要指定两个点的种类和边的种类。graph[(source_type, edge_type, dest_type)].edge_index = (source,dest) [2, num_edges] # 借鉴GraphSage的邻居采样dataload,每次训练不使用整张图,可以分batchtrain_loader = NeighborLoader(graph, input_nodes=('要分类的type', train_idx), num_neighbors=[a] * b 往外采样b层,每层每种边a个,内存够a可以填-1 , shuffle=True, batch_size=128)for batch in train_loader():batch['item'].batch_size = 128batch['item'].x =[num, 256] 前batch_size个是要预测的点,其他为采样出来的点。batch['item'].y =[num] 前batch_size个是预测点的label,其他无用。batch = batch.to_homogeneous() 转化为同构图batch.x = [所有点数量, 256] batch.edge_idx = [2, 所有边数量] 记录所有边batch.edge_type = [所有边数量] 记录边的类型model(batch.x,batch.edge_index,batch.edge_type)RGCN

RGCN比较简单,其实就是借鉴GCN处理同构图的思路,将其运用到处理异构图上。

GCN的基本思想就是为了计算下一层i节点的embedding,拿出上一层和i相邻的节点和i节点本身的embedding,将这些embedding乘上对应的网络要学习的变化权重矩阵W,前面再乘上单位矩阵和归一化矩阵,每一层的W用同一个,类比卷积。

RGCN很简单,异构图不是有很多种边吗,我就把不同种类的边分开来,每种关系一张图,这样这张图上边都是一样的了,理所当然使用GCN共享W矩阵,求出这种关系下节点i的embedding,最后所有关系的embedding来个融合,随便加个权,来个relu激活一下完成。

from torch_geometric.nn import RGCNConvclass RGCN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, n_layers=2, dropout=0.5): super().__init__() self.convs = torch.nn.ModuleList() self.relu = F.relu self.dropout = dropout self.convs.append(RGCNConv(in_channels, hidden_channels, num_relations)) for i in range(n_layers - 2): self.convs.append(RGCNConv(hidden_channels, hidden_channels, num_relations)) self.convs.append(RGCNConv(hidden_channels, out_channels, num_relations)) def forward(self, x, edge_index, edge_type): for conv, norm in zip(self.convs, self.norms): x = norm(conv(x, edge_index, edge_type)) x = F.relu(x) x = F.dropout(x, p=self.dropout, training=self.training return xRGAT异构图神经网络 RGCN、RGAT、HAN、GNN-FILM + PyG实现(异构图神经网络 电影推荐)

由于RGCN每一层W都是固定的,不够灵活,所以加入attention机制,毕竟万物皆可attention。

先说一下GAT在GCN上的改动,在计算i节点的embedding时,还是拿出和它邻近的节点和它自己的embedding,对于每一个这样的节点j,将i,j节点的embedding拼接,变成两倍长度,然后算一个self-attention,好像就是一个单层前馈网络,就得到节点j相对于节点i的权重。

RGAT一样,在关系上下功夫,利用关系特征再算一个attention。 最后两者做融合 RGAT可以看成是RGCN进化版,在attention不起作用的时候会退化成RGCN。

但实战和RGCN不分伯仲,甚至在本次竞赛的场景中逊色于RGCN。原因见论文:

RGAT通过attention机制比较好的完成任务之后,很难在损失机制反馈的作用下找到那个把attention设置成归一化常数后效果更好的点。RGCN在一些任务上会通过记忆样本的方式提升效果,但是RGAT模型更复杂发生这种情况的概率更低。from torch_geometric.nn import RGATConvclass RGAT(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, n_layers=2, n_heads=3): super().__init__() self.convs = torch.nn.ModuleList() self.relu = F.relu self.convs.append(RGATConv(in_channels, hidden_channels, num_relations, heads=n_heads, concat=False)) for i in range(n_layers - 2): self.convs.append(RGATConv(hidden_channels, hidden_channels, num_relations, heads=n_heads, concat=False)) self.convs.append(RGATConv(hidden_channels, hidden_channels, num_relations, heads=n_heads, concat=False)) self.lin1 = torch.nn.Linear(hidden_channels, out_channels) def forward(self, x, edge_index, edge_type): for i, conv in enumerate(self.convs): x = conv(x, edge_index, edge_type) x = x.relu_() x = F.dropout(x, p=0.2, training=self.training x = self.lin1(x) return xHeterogeneous Graph Attention Network (HAN HGAT)

根据专家经验设置多条matapath(路径):点、边、点、边、点…

针对不同的matapath,节点i针对路径拿到其所有邻居节点j。

1.点和点计算attention并求和。使用多头注意力机制。

2.所有关系要聚合时算一个attention,其中q,w,b共享。 实验中效果很差,可能是我matapath设置的不好吧,而且多头注意力训练时间也太久了,我RGCN一个epoch只要5min,它要480min。

from torch_geometric.nn import HANConvlabeld_class = 'item'class HAN(torch.nn.Module): def __init__(self, in_channels: Union[int, Dict[str, int]], out_channels: int, hidden_channels=16, heads=4, n_layers=2): super().__init__() self.convs = torch.nn.ModuleList() self.relu = F.relu self.convs.append(HANConv(in_channels, hidden_channels, heads=heads, dropout=0.6, metadata=metada)) for i in range(n_layers - 1): self.convs.append(HANConv(hidden_channels, hidden_channels, heads=heads, dropout=0.6, metadata=metada)) self.lin = torch.nn.Linear(hidden_channels, out_channels) def forward(self, x_dict, edge_index_dict): for i, conv in enumerate(self.convs): x_dict = conv(x_dict, edge_index_dict) x_dict = self.lin(x_dict[labeled_class]) return x_dict GNN-Film(线性特征调整)

对比RGCN,改动的点与RGAT类似,同样想使得权重有所变化。加入了一个简单的前馈网络: 优点在于他在算权重的时候,加了一个仿射变换,相当于是用神经网络去计算参数。再用b和y去作为权重调整embedding。

实验中效果出奇的好,训练快,效果超越RGCN。

class GNNFilm(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, n_layers, dropout=0.5): super().__init__() self.dropout = dropout self.convs = torch.nn.ModuleList() self.convs.append(FiLMConv(in_channels, hidden_channels, num_relations)) for _ in range(n_layers - 1): self.convs.append(FiLMConv(hidden_channels, hidden_channels, num_relations)) self.norms = torch.nn.ModuleList() for _ in range(n_layers): self.norms.append(BatchNorm1d(hidden_channels)) self.lin_l = torch.nn.Sequential(OrderedDict([ ('lin1', Linear(hidden_channels, int(hidden_channels//4), bias=True)), ('lrelu', torch.nn.LeakyReLU(0.2)), ('lin2', Linear(int(hidden_channels//4),out_channels, bias=True))])) def forward(self, x, edge_index, edge_type): for conv, norm in zip(self.convs, self.norms): x = norm(conv(x, edge_index, edge_type)) x = F.dropout(x, p=self.dropout, training=self.training) x = self.lin_l(x) return x总结

RGCN、RGAT、GNN-FILM代码替换十分简单,训练代码完全不用动,只要改模型代码即可,完全可以三者都尝试效果,HAN慎用,效果太吃matapath的设置,训练时间还长,不值得。

本文链接地址:https://www.jiuchutong.com/zhishi/283704.html 转载请保留说明!

上一篇:口腔发炎怎么办(口腔发炎怎么办最快最有效的方法)

下一篇:亡灵节上点缀公墓的万寿菊,墨西哥米却肯州 (© Irwin Barrett/Design Pics/Alamy)(亡灵节mid)

  • vivo怎么设置动态壁纸锁屏(vivo怎么设置动态锁屏)

    vivo怎么设置动态壁纸锁屏(vivo怎么设置动态锁屏)

  • 孩子辽事通怎么注册

    孩子辽事通怎么注册

  • 惠普电脑怎么拍照(惠普电脑怎么拍照没有拍照键)

    惠普电脑怎么拍照(惠普电脑怎么拍照没有拍照键)

  • 淘宝如何修改性别(淘宝如何修改性别和价格)

    淘宝如何修改性别(淘宝如何修改性别和价格)

  • 微信互删 头像是否更新(微信互删头像变灰)

    微信互删 头像是否更新(微信互删头像变灰)

  • 投屏网速慢怎么办(投屏网速不稳定)

    投屏网速慢怎么办(投屏网速不稳定)

  • 钉钉投屏到电视上算不算观看时间(钉钉投屏到电视投屏码怎么获得)

    钉钉投屏到电视上算不算观看时间(钉钉投屏到电视投屏码怎么获得)

  • 淘宝订单多少天自动确认收货(淘宝订单多少天后不能申请售后)

    淘宝订单多少天自动确认收货(淘宝订单多少天后不能申请售后)

  • 华为荣耀9x怎么截屏(华为荣耀9x怎么取卡)

    华为荣耀9x怎么截屏(华为荣耀9x怎么取卡)

  • iphone锁屏后qq下线了(苹果手机锁屏之后qq还会在线吗)

    iphone锁屏后qq下线了(苹果手机锁屏之后qq还会在线吗)

  • 京东学生认证会毕业自动取消吗(京东学生认证会有什么影响吗)

    京东学生认证会毕业自动取消吗(京东学生认证会有什么影响吗)

  • 钉钉撤销能看到痕迹吗(钉钉撤销消息了还能看见吗)

    钉钉撤销能看到痕迹吗(钉钉撤销消息了还能看见吗)

  • vivo x7plus上市时间(vivox7plus什么时候出来的)

    vivo x7plus上市时间(vivox7plus什么时候出来的)

  • 哪款ipad可以插卡(哪款ipad可以插手机卡打电话的)

    哪款ipad可以插卡(哪款ipad可以插手机卡打电话的)

  • 苹果手机如何设置面容解锁(苹果手机如何设置小圆点快捷键)

    苹果手机如何设置面容解锁(苹果手机如何设置小圆点快捷键)

  • 华为mate30有耳机孔吗(华为mate30有耳机吗)

    华为mate30有耳机孔吗(华为mate30有耳机吗)

  • 如何判断耳机是否漏音(如何判断耳机是不是airpods)

    如何判断耳机是否漏音(如何判断耳机是不是airpods)

  • 苹果手机怎么连蓝牙耳机(苹果手机怎么连接车载carplay)

    苹果手机怎么连蓝牙耳机(苹果手机怎么连接车载carplay)

  • 惠普126nw和132nw区别(惠普126nw和132nw和132snw)

    惠普126nw和132nw区别(惠普126nw和132nw和132snw)

  • 苹果11正式发售时间(苹果正式发售能买到吗)

    苹果11正式发售时间(苹果正式发售能买到吗)

  • Oppo reno的操作系统是那种(oppo手机的操作系统)

    Oppo reno的操作系统是那种(oppo手机的操作系统)

  • xsmax的小白点在哪(xsmax白点多功能键)

    xsmax的小白点在哪(xsmax白点多功能键)

  • 快手发布失败怎么回事(快手作品发送失败怎么找回)

    快手发布失败怎么回事(快手作品发送失败怎么找回)

  • 备份计算机需要怎么做?(备份计算机需要多久)

    备份计算机需要怎么做?(备份计算机需要多久)

  • 奥杜邦中心的一只靛蓝彩鹀,美国宾夕法尼亚州 (© Vicki Jauron/Getty Images)(奥杜邦的祈祷经典语录)

    奥杜邦中心的一只靛蓝彩鹀,美国宾夕法尼亚州 (© Vicki Jauron/Getty Images)(奥杜邦的祈祷经典语录)

  • opencv调用yolov7 yolov7 c++ yolov7转onnx opencv调用yolov7 onnx(opencv调用yolov8)

    opencv调用yolov7 yolov7 c++ yolov7转onnx opencv调用yolov7 onnx(opencv调用yolov8)

  • c语言中assert函数的使用注意(c语言中asin)

    c语言中assert函数的使用注意(c语言中asin)

  • 会计新手如何学会收款流程
  • 运动手环的税收分类编码是
  • 公司公积金缴纳比例一般来说是多少?
  • 总公司是一般纳税人吗
  • 其他收入月末需要结账吗
  • 债权转增资本应缴纳什么税
  • 3项经费计提比例2015
  • 以前年度损益调整属于哪类科目
  • 开发成本可以计增值税吗
  • 资产收益权转让产品
  • 电梯维修增值税
  • 出口押汇与打包押汇区别
  • 公司发年终奖怎么发朋友圈
  • 企业所得税政策最新2023税率
  • 企业员工奖励款怎么做账
  • 全国增值税发票查询平台 手机版
  • 大中小微企业划分标准2023年
  • 长期股权投资账务处理
  • 建筑工程总包分包的内容
  • 工程结算收入以前年度多结转收入怎么处理?
  • 发票金额大于开票金额
  • 错账查找的方法
  • 网速怎么限制10mb以内
  • php解析原理
  • sistray.exe - sistray是什么进程 有什么用
  • 苹果mac系统怎么更新最新版本
  • 在建工程进项税额抵扣规定
  • PHP:curl_setopt_array()的用法_cURL函数
  • PHP:is_uploaded_file()的用法_Filesystem函数
  • 代码怎么用?
  • 相思树学名叫什么
  • 发票多久过期不能开
  • thinkphp框架介绍
  • 实收资本(或股本)是什么意思
  • php高并发api接口怎么处理
  • vue面试题视频
  • elementui动态表单数据回显
  • php的序列化操作生成的哪种格式
  • 会计证的作用和用途
  • 进项税额转出加计抵减会计分录
  • 评价股权转让要交什么税
  • 企业折旧申报备案怎么写
  • mongo mysql区别
  • sql2017附加数据库
  • 财务报表是指的什么内容
  • 什么是指企业的市场营销活动发生影响的各种因素的总和
  • 企业信用公示的时候医疗和生育怎么分开计算
  • 城市基础设施配套费征收管理规定
  • 账面价值和公允价值的关系
  • 企业哪些行为可以避税
  • 资产减值损失的科目编码
  • 预付账款的账务处理过程
  • 税务销售滞后是什么意思
  • 营业外收入是否影响营业利润
  • 员工工资占公司收入
  • 购进货物取得
  • 加盟店直营店什么意思
  • 一般户和基本户怎么使用最好
  • 私营企业归谁管
  • 明细账设置是什么意思
  • win7咋样
  • 服务器centos版本选择
  • winload是什么
  • 丢失acui16.dll
  • 因为你的策略组阻止
  • 怎么检测软件有没有毒
  • win7系统鼠标右键无法弹出菜单
  • xp如何一键还原系统还原
  • linux中安装jdk1.8
  • win8怎么样的
  • 苹果mac最新的系统
  • bat批处理命令大全
  • jquery设计模式
  • 原生js实现ajax步骤
  • javascript高级程序设计pdf下载
  • 新浪微博手机客户端下载
  • 钢结构蔬菜大棚造价多少钱一平方
  • 税务注销核对发票怎么办
  • 广西税务局发票查验平台
  • 中国进口车关税为什么那么贵
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设