位置: IT常识 - 正文

(四)孪生神经网络介绍及pytorch实现(孪生神经网络 计算相似度)

编辑:rootadmin
(四)孪生神经网络介绍及pytorch实现

推荐整理分享(四)孪生神经网络介绍及pytorch实现(孪生神经网络 计算相似度),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:孪生神经网络 pytorch,孪生物什么意思,孪生物什么意思,孪生神经网络图像分类,孪生神经网络 pytorch,孪生神经网络图像分类,孪生神经网络模型,孪生神经网络应用,内容如对您有帮助,希望把文章链接给更多的朋友!

欢迎访问个人网络日志🌹🌹知行空间🌹🌹

孪生神经网络介绍及pytorch实现1.孪生神经网络2.孪生神经网络的损失函数2.1 Triplet Loss2.2 Contrastive Loss3.动手实现一个孪生网络3.1 网络结构3.2 损失函数3.3 数据3.4 训练结果4.SiameseNetWork的一些应用参考资料1.孪生神经网络

在深度学习领域,神经网络取得了成功。但普通的神经网络模型的训练需要大量的数据,对于一些数据有限的场景,如人脸验证,签字验证,必须考虑其他方法。

Siamese 古语表示瞿罗,即现在的泰国,如Siamese cat,之所以Siamese表示孪生,是因为19世纪瞿罗出了一对连体双胞胎,在美国玲玲马戏团做演出比较出名,因此提起Siamese即表示孪生的意思。1

孪生神经网络Siamese Network,如其名字孪生Siamese的意思即存在连体,连体即彼此共享一部分。孪生神经网络的结构也包括两个子网络,两个子网络之间共享权重。

> 图片来自于1

如上图,两个网络是同一个并共享权重,当两个子网络不共享权重时,通常定义为伪孪生神经网络。

图片来自于1

从上面的图中可以看出来,孪生神经网络有两个输入,input1和input2,因此孪生神经网络常用来通过比较两个输入特征向量的距离来衡量两个输入的相似度。早在1993年的NIPS上Yann Lecun就发表了使用孪生神经网络做签名验证的论文。现在的人脸识别应用也有基于孪生神经来做的。

孪生神经网络的优点,对于类别不平衡问题更鲁棒,更易于做集成学习(Ensemble Learning),可以从语义相似性上学习来估测两个输入的距离。孪生神经网络的缺点,由于有两个输入,两个子网,其训练相对于常规网络运算量更大,需要的时间更长。输出的结果不是概率,孪生神经网络时成对的输入,其输出是两个类间的距离而不是概率。

2.孪生神经网络的损失函数

由与孪生神经网络是计算的两个输入的相似度,距离,而不是对输入做分类,因此交叉商损失函数不适用于此种场景,孪生神经网络的常用的损失函数有Triplet Loss和Contrastive Loss。

2.1 Triplet Loss

Triplet Loss三元组损失函数,其应用见谷歌2015年发表在CVPR上的做人脸验证的论文facenet。该损失函数定义一个三元组作为输入,分别是(Xanchor,Xpositive,Xnegative)(X_{anchor},X_{positive},X_{negative})(Xanchor​,Xpositive​,Xnegative​)这三个输入的通过如下方式构成,先从训练数据集中随机选一个样本作为Anchor,再随机选取一个和Anchor属于同一类的样本作为正样本XpositiveX_{positive}Xpositive​,和一个不同类的样本作为负样本XnegativeX_{negative}Xnegative​,通过这种方式定义一个输入的三元组(Xanchor,Xpositive,Xnegative)(X_{anchor},X_{positive},X_{negative})(Xanchor​,Xpositive​,Xnegative​),将其输入到网络可以得到对应的特征向量[f(Xanchor),f(Xpositive),f(Xnegative)][f(X_{anchor}),f(X_{positive}),f(X_{negative})][f(Xanchor​),f(Xpositive​),f(Xnegative​)],Triplet Loss的目的是通过训练,使得同种类别的距离更近,不通类别的距离更大,即拉近anchor与positive推远anchor和negative,如下图:

图片来自FaceNet论文

通过这种相似度比较式的学习,模型不仅与同类别更像,还学会了与不同类别增大区分度的信息。通常定义一个α\alphaα,使得Anchor距离Negative的距离比距离Positive大α\alphaα,公式化表示为:

(四)孪生神经网络介绍及pytorch实现(孪生神经网络 计算相似度)

∣∣f(Xanchor)−f(Xnegative)∣∣−∣∣f(Xanchor)−f(Xpositive)∣∣>α||f(X_{anchor}) - f(X_{negative})|| - ||f(X_{anchor}) - f(X_{positive})|| \gt \alpha∣∣f(Xanchor​)−f(Xnegative​)∣∣−∣∣f(Xanchor​)−f(Xpositive​)∣∣>α

定义为:

L(Xanchor,Xpositive,Xnegative)=max(∣∣f(Xanchor)−f(Xpositive)∣∣−∣∣f(Xanchor)−f(Xnegative)∣∣+α,)L(X_{anchor}, X_{positive}, X_{negative}) = max(||f(X_{anchor}) - f(X_{positive})|| - ||f(X_{anchor}) - f(X_{negative})|| + \alpha, 0)L(Xanchor​,Xpositive​,Xnegative​)=max(∣∣f(Xanchor​)−f(Xpositive​)∣∣−∣∣f(Xanchor​)−f(Xnegative​)∣∣+α,0)

2.2 Contrastive Loss

衡量相似度的另一常用函数是Yann Lecun在2005年的一篇论文Dimensionality Reduction by Learning an Invariant Mapping中使用的Contrastive Loss。

Contrastive Loss的输入是一对样本,基于相似的一对对象特征距离应该更小,不相似的一对对象特征距离应该较大来计算。从数据中选一对样本(Xa,Xb)(X_a, X_b)(Xa​,Xb​),这两个样本的欧式距离表示为d=∣∣Xa−Xb∣∣2=(Xa−Xb)2d=||X_a-X_b||_2=\sqrt{({X_a-X_b})^2}d=∣∣Xa​−Xb​∣∣2​=(Xa​−Xb​)2​,则Contrastive Loss可表示为: L(Xa,Xb)=(1−Y)12d2+Y12{max(,m−d)}2L(X_a,X_b) = (1-Y)\frac{1}{2}d^2 + Y\frac{1}{2}\{max(0, m-d)\}^2L(Xa​,Xb​)=(1−Y)21​d2+Y21​{max(0,m−d)}2

Y表示(Xa,Xb)(X_a,X_b)(Xa​,Xb​)是否匹配,匹配为1不匹配为0

m是设置的安全距离,当(Xa,Xb)(X_a, X_b)(Xa​,Xb​)的距离小于mmm时,Contrasive Loss将变成0,这使得XaX_aXa​与XbX_bXb​相似而不是相同,能保证算法的泛化能力

3.动手实现一个孪生网络3.1 网络结构

这里使用Contrasive Loss定义一个孪生神经网络,网络结构如图:

这里上下两个网络使用同一个网络来实现,对于两个输入,每一步推理使用相同的权重forward两次,然后计算损失函数更新权重,这里并没有定义两个网络。为了简化训练,自定义了比较小的网络

class SiameseNetwork(nn.Module): """Custom Siamese Network """ def __init__(self): super(SiameseNetwork, self).__init__() self.cnn = nn.Sequential( nn.Conv2d(1, 128, kernel_size=5, stride=3, padding=2), # 10 nn.ReLU(inplace=True), nn.LocalResponseNorm(5, alpha=0.001, beta=.75, k=2), # TODO nn.MaxPool2d(4, stride=2), # 4 nn.Dropout2d(p=.5), ) # 12544 self.fc = nn.Sequential( nn.Linear(2048, 512), nn.ReLU(inplace=True), nn.Dropout2d(p=0.5), nn.Linear(512, 128), nn.ReLU(inplace=True), nn.Linear(128, 2) ) def forward_once(self, x): y = self.cnn(x) y = y.view(y.size()[0], -1) y = self.fc(y) return y def forward(self, x1, x2): y1 = self.forward_once(x1) y2 = self.forward_once(x2) return y1, y23.2 损失函数

损失函数使用的是前述的Contrastive Loss,其定义为:

class ContrastiveLoss(torch.nn.Module): def __init__(self, margin): super(ContrastiveLoss, self).__init__() self.margin = margin def forward(self, x1, x2, y): dist = F.pairwise_distance(x1, x2) total_loss = (1-y) * torch.pow(dist, 2) + \ y * torch.pow(torch.clamp_min_(self.margin - dist, 0), 2) loss = torch.mean(total_loss) return loss3.3 数据

这里使用的是基于MNIST数据集随机选取的1000张图像然后生成了8000对作为输入来训练的,测试时输入两张手写字图片输出其相似度。

3.4 训练结果

训练了20个epoch,损失函数值的变化趋势如下图:

由于使用的batch_size较小,迭代的次数较少,可以看到损失函数没有很好的收敛。且打开训练数据看了下自己生成的train.csv中的图像对,绝大部分label都是0,存在严重的数据不平衡问题,需要改进。在测试数据上的输出,对于有些输入可以比较好的衡量其相似度。

Predicted Distance: 0.0020178589038550854Actual Label: Different SignaturePredicted Distance: 0.0002805054828058928Actual Label: Same SignaturePredicted Distance: 0.003011130029335618Actual Label: Different SignaturePredicted Distance: 0.0018709745490923524Actual Label: Different Signature

完整代码见gitee仓库

4.SiameseNetWork的一些应用

1.签名验证Signature Verification using a “Siamese” Time Delay Neural Network

2.三胞胎网络Deep metric learning using Triplet network

3.One-ShotLearning, Siamese Neural Networks for One-shot Image Recognition

4.人脸验证Learning a Similarity Metric Discriminatively, with Application to Face Verification

参考资料1.Siamese network 孪生神经网络–一个简单神奇的结构2.FaceNet3.Contrastive Loss4.A friendly introduction to Siamese Networks

欢迎访问个人网络日志🌹🌹知行空间🌹🌹

本文链接地址:https://www.jiuchutong.com/zhishi/298737.html 转载请保留说明!

上一篇:ChatGPT 被大面积封号,到底发生什么了?

下一篇:Ubuntu22.04 下安装驱动、CUDA、cudnn以及TensorRT(ubuntu20.04.1安装)

  • 支付宝2021中秋付款码皮肤怎么领取(支付宝中秋月饼活动)

    支付宝2021中秋付款码皮肤怎么领取(支付宝中秋月饼活动)

  • ps怎么截图想要的区域(ps怎么截图想要的区域快捷键)

    ps怎么截图想要的区域(ps怎么截图想要的区域快捷键)

  •  如何让qq电话来电无声(如何让qq电话来电没声音)

    如何让qq电话来电无声(如何让qq电话来电没声音)

  • 微信注销非法请求是什么意思(微信注销非法请求的原因及解决办法)

    微信注销非法请求是什么意思(微信注销非法请求的原因及解决办法)

  • 怎么解除微信被别人同步(怎么解除微信被限制)

    怎么解除微信被别人同步(怎么解除微信被限制)

  • b站音频怎么下载到手机(b站音频怎么下载到本地)

    b站音频怎么下载到手机(b站音频怎么下载到本地)

  • 抖音里的通讯录好友不见了去哪里找(抖音里的通讯录在哪里找?)

    抖音里的通讯录好友不见了去哪里找(抖音里的通讯录在哪里找?)

  • 华为P30怎样下载优酷(华为p30怎样下载电影)

    华为P30怎样下载优酷(华为p30怎样下载电影)

  • 华为nova3怎么隐藏应用(华为nova3怎么隐藏应用图标)

    华为nova3怎么隐藏应用(华为nova3怎么隐藏应用图标)

  • 抖音赞全部作品会被限流吗(抖音赞全部作品怎么看)

    抖音赞全部作品会被限流吗(抖音赞全部作品怎么看)

  • 荣耀30有红外线功能吗(荣耀80支持红外线吗)

    荣耀30有红外线功能吗(荣耀80支持红外线吗)

  • win7支持gpt吗(win7可以gpt)

    win7支持gpt吗(win7可以gpt)

  • 美团差评申诉不成功对店铺有没有影响(美团差评申诉不通过,怎么投诉评审专员)

    美团差评申诉不成功对店铺有没有影响(美团差评申诉不通过,怎么投诉评审专员)

  • 美团众包的评论是什么时候显示(美团众包评论如何看是哪一单评论的)

    美团众包的评论是什么时候显示(美团众包评论如何看是哪一单评论的)

  • 电脑上commander的意思(电脑上的command在啥地方)

    电脑上commander的意思(电脑上的command在啥地方)

  • 华为手机克隆用数据流量吗(华为手机克隆用不了怎么办)

    华为手机克隆用数据流量吗(华为手机克隆用不了怎么办)

  • qq举报会不会被知道(qq举报会不会被发现 知乎)

    qq举报会不会被知道(qq举报会不会被发现 知乎)

  • 安卓手机怎么变成苹果系统(安卓手机怎么变成电脑模式)

    安卓手机怎么变成苹果系统(安卓手机怎么变成电脑模式)

  • 苹果虚拟机和双系统的区别(苹果虚拟机和双系统区别在哪)

    苹果虚拟机和双系统的区别(苹果虚拟机和双系统区别在哪)

  • 电脑ie系列浏览器有哪些(电脑ie系列浏览器怎么用)

    电脑ie系列浏览器有哪些(电脑ie系列浏览器怎么用)

  • 拼多多能收藏多少商品(拼多多收藏多少商品)

    拼多多能收藏多少商品(拼多多收藏多少商品)

  • 手机上怎么注销手机号(手机上怎么注销etc)

    手机上怎么注销手机号(手机上怎么注销etc)

  • 手机银行怎么交学费(手机银行怎么交短信通知费)

    手机银行怎么交学费(手机银行怎么交短信通知费)

  • 快手极速版可以发作品吗(快手极速版可以扫码登录吗)

    快手极速版可以发作品吗(快手极速版可以扫码登录吗)

  • excel表如何自动更新(Excel表如何自动调整列宽)

    excel表如何自动更新(Excel表如何自动调整列宽)

  • qq视频已过期怎么找回(qq视频已过期怎么找回来)

    qq视频已过期怎么找回(qq视频已过期怎么找回来)

  • qq怎么防撤回(qq访客记录删除了怎么恢复)

    qq怎么防撤回(qq访客记录删除了怎么恢复)

  • uniapp实现上拉加载更多(uniapp下拉)

    uniapp实现上拉加载更多(uniapp下拉)

  • 用Python来统计本机CPU利用率(python进行统计分析)

    用Python来统计本机CPU利用率(python进行统计分析)

  • 私车公用的税务风险
  • 缴纳印花税的会计凭证
  • 个人私活 要交个人所得税吗
  • 销项税额的计算方法
  • 总公司设立分公司的决定
  • 企业清算时未抵扣的进项税账务处理
  • 营改增土地增值税的计算
  • 国债逆回购收益什么时候到账
  • 物业费属于什么合同
  • 银行流动性比例要求
  • 个人独资企业是什么意思
  • 外籍人员取得数月奖金怎么交税
  • 残疾人名下有房产可以申请残疾人补贴吗?
  • 管理会计完全成本法和变动成本法例题
  • windows所有应用
  • 主营业务收入多栏式怎么填
  • 工程机械租赁公司图片
  • 收到厂家返利怎么做分录
  • 费用报销操作流程
  • 怎么计提企业所得税在哪里知道计提多少
  • 借条的标准格式 手写学生
  • linux常用命令make
  • 华为鸿蒙harmonyos官网4.0升级
  • 进口货物怎样报关
  • 默认网关不可用的解决办法
  • 怎么登明细分类账
  • 转出固定资产账务处理
  • 残疾人取得房屋所有权
  • 企业受赠业务的法律规定
  • 今天中秋节
  • 海关进口增值税怎么认证抵扣
  • 金融负债期末可以转出吗
  • 合伙企业分配股票给合伙人
  • php编辑器哪个好
  • 企业没有实缴
  • 工会经费免征三年的文件山东
  • chrome os安装到u盘
  • 遮天传游戏视频
  • css 入门
  • 房产置换怎么做账务处理
  • 个体工商户该如何开发票
  • 房屋租赁需要计增值税吗
  • python类的继承与多态
  • 5年前开的发票退货可以冲红吗?
  • 代扣代缴个人社保账务处理
  • mongodb数据库中使用哪个数据库可以设置用户名和密码
  • 购买一台电脑2400元贵吗
  • 房产税从租和从价都要交吗
  • 电子发票开错了应该怎么办?
  • 上级拨付的债券怎么做账
  • 销售送客户礼物
  • 其他货币资金期末有余额吗
  • 单位购买的化妆品怎么用
  • 进项税额转出是借方科目还是贷方科目
  • 本月冲红上月发票后的税款能抵减吗
  • 收到退回的增值税,应当作为营业外收入核算对吗
  • 季节性生产企业有哪些
  • 国有资产如何保值
  • 接受捐赠的增值税怎么处理
  • 担保公司的担保费能退吗
  • 应收票据和应付票据的区别
  • 咨询服务费计入哪里
  • 工程发票可以抵扣增值税吗
  • 投资款计入哪个科目
  • myeclipse连接mysql失败
  • 如何进入opencore引导
  • piped.exe
  • load its core dll
  • Win7开机黑屏只有鼠标,进入安全模式也是黑屏
  • windows8输入法
  • win8杀毒软件关闭
  • js的三种循环
  • jquery移动端ui框架
  • python关键字none
  • js计算字体宽度
  • jquery制作左导航特效
  • jqgrid tree
  • 湖北省叉车考试题库
  • 西安代驾平台有哪些
  • 云南2021高考改革
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设