位置: IT常识 - 正文

(二)元学习算法MAML简介及代码分析(二元运算例子)

编辑:rootadmin
(二)元学习算法MAML简介及代码分析

推荐整理分享(二)元学习算法MAML简介及代码分析(二元运算例子),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:二元运算什么意思,二元计算题及答案过程,二元运算,二元计算题及答案过程,二元计算题及答案过程,二元计算题及答案过程,二元计算题及答案过程,二元运算,内容如对您有帮助,希望把文章链接给更多的朋友!

欢迎访问个人网络日志🌹🌹知行空间🌹🌹

元学习算法MAML简介1.元学习(meta learning)2.模型无关元学习2.1 元学习问题建模2.2 MAML算法3.将MAML应用到回归分类任务上的算法流程4.代码解读参考资料

论文: Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks Chelsea

代码: https://github.com/cbfinn/maml

ICML2017的一篇论文,作者Chelsea Finn是斯坦福的老师,一不小心去作者主页看了下,MIT和伯克利的学生,真强。_

元学习MAML论文介绍

模型无关元学习算法,即Model-Agnostic Meta-Learning Algorithm(MAML)。

1.元学习(meta learning)

元学习即学会学习,区别与普通的深度学习过程。普通的深度学习具体到某一任务,如图像分类,即训练一个模型实现一个数据集内的图像分类,这种方法有一定的局限性,即模型只能在当前任务(task)上工作,不能应用到其他任务。譬如基于手写字识别数据集训练的分类模型不能用来实现猫和狗的分类。有没有一种方法,可以学会完成分类这一任务,不针对具体是实现哪些对象的分类,学会分类任务后再基于少量的具体数据训练学会是具体给猫狗分类还是给手写字分类。相当于说一个模型实现了原来多个模型的功能。

元学习训练模型是为了获得一个可以快速应用到小样本数据的新任务上的模型,元学习通过初步训练获得模型比较好的初值,再基于初值对具体任务在小样本训练数据上少量更新权重即可取得好的效果。

元学习还可以理解成是寻找一组具有较高敏感度的参数,基于找到的参数,只需要进行少量的迭代即可在新的任务上取得理想的结果。

元学习可应用于训练数据有限的Few-Shot Learning任务。

2.模型无关元学习2.1 元学习问题建模

元学习是在一系列任务上学习,目标是学习得到一个比较敏感的模型,使该模型能够基于小样本数据简单训练快速应用到新任务上。也就是说,元学习将一系列学习任务当作训练样本。

譬如,识别一个动物是不是狗是任务T1T_1T1​,识别一个手写数字是不是9是任务T2T_2T2​,识别一辆车是不是坦克是任务T3T_3T3​,普通的学习方法会针对每个训练一个模型,也是基于前述的任务要训练3个模型分别完成。观察前面的三个任务T1,T2,T3T_1,T_2,T_3T1​,T2​,T3​具有共性,即都是识别分类任务,能不能有一种通用模型可以学习识别分类这一任务,然后再基于少量的数据对通用模型微调即可快速应用的新的类似任务。如基于T1,T2,T2T1,T2,T2T1,T2,T2使模型学会分类能力,然后提供少量的是否是飞机的训练数据,即可快速学会判断天空中的一个物体是否是飞机。

使用数学公式描述:

(二)元学习算法MAML简介及代码分析(二元运算例子)

单个任务表示为: T={L(X1,a1,...,XH,aH),q(X1),q(Xt+1∣Xt,at),H}T=\{L(X_1,a_1,...,X_H,a_H),q(X_1),q(X_{t+1}|X_t,a_t),H\}T={L(X1​,a1​,...,XH​,aH​),q(X1​),q(Xt+1​∣Xt​,at​),H}

XXX是输入aaa是输出LLL是损失函数q(X1)q(X_1)q(X1​)是初始输入变量的概率分布q(Xt+1∣Xt,at)q(X_{t+1}|X_t,a_t)q(Xt+1​∣Xt​,at​)是输入变量的状态转移分布HHH输入变量序列的长度,对于监督学习问题,其值为1,应用在强化学习等中。L(X1,a1,...,XH,aH)→RL(X_1,a_1,...,X_H,a_H) \rightarrow \RL(X1​,a1​,...,XH​,aH​)→R是针对具体任务的损失函数,如回归问题通常是均方误差(Mean Square Error, MSE),分类问题通常是交叉商(Cross Entropy, CE)。

在元学习(meta-learning)中,考虑多个任务TTT的分布为p(T)p(T)p(T),这正是元学习模型要学习的目标。具体的任务TiT_iTi​是从任务分布p(T)p(T)p(T)中取样的,模型的训练基于任务TiT_iTi​的KKK个训练样本和任务TiT_iTi​的损失函数LiL_iLi​。任务TiT_iTi​的测试误差,将作为元学习模型的训练误差。

上图中∇L1,∇L2,∇L3\nabla L_1,\nabla L_2,\nabla L_3∇L1​,∇L2​,∇L3​分别表示任务T1,T2,T3T_1,T_2,T_3T1​,T2​,T3​上的损失函数梯度,θ1⋆,θ2⋆,θ3⋆\theta_1^\star,\theta_2^\star,\theta_3^\starθ1⋆​,θ2⋆​,θ3⋆​分别表示具体到任务T1,T2,T3T_1,T_2,T_3T1​,T2​,T3​上的参数,θ\thetaθ是元学习模型的参数。

2.2 MAML算法

算法中参数更新分成两步,一次是更新 θ′\theta'θ′,之后才是更新θ\thetaθ。这和元学习的的定义相关。θ′\theta'θ′的更新是在具体某个Taski{Task}_iTaski​上学习时发生的,而元学习的目标是找到一组参数θ\thetaθ能够对多个任务TaskTaskTask都具有表征能力。所以thetathetatheta的更新过程分成了两个,先是针对具体任务TaskiTask_iTaski​的更新优化后是针对元学习模型的优化。

第一步,针对任务TiT_iTi​的模型优化为:

θ′=θ−α∇θLTi(fθ)\theta'=\theta-\alpha\nabla_{\theta}L_{T_i}(f_\theta)θ′=θ−α∇θ​LTi​​(fθ​)

fθf_\thetafθ​表示元学习模型

第二步,针对元学习模型的优化为:

minθ∑Ti∼p(T)LTi(fθ′)=∑Ti∼p(T)LTi(fθ−α∇θLTi(fθ))θ←θ−β∇θ∑Ti∼p(T)LTi(fθ′)\mathop{min}\limits _\theta \sum\limits_{T_i\sim p(T)}L_{T_i}(f_\theta')=\sum\limits_{T_i\sim p(T)}L_{T_i}(f_{\theta-\alpha\nabla_{\theta}L_{T_i}(f_\theta)}) \\ \\ \theta \leftarrow \theta - \beta\nabla_\theta\sum\limits_{T_i\sim p(T)}L_{T_i}(f_\theta')θmin​Ti​∼p(T)∑​LTi​​(fθ′​)=Ti​∼p(T)∑​LTi​​(fθ−α∇θ​LTi​​(fθ​)​)θ←θ−β∇θ​Ti​∼p(T)∑​LTi​​(fθ′​)

3.将MAML应用到回归分类任务上的算法流程

方程2和方程3分别是均方误差和交叉熵。

4.代码解读

MAML原作者的代码是基于tensorflow 1.x版本实现的,结构比较清晰。 模型封装了一个MAML类,数据的加载在类DataGenerator中。

main.py的train函数中定义了metatrain的过程:

# metatrain_iterations是元学习模型训练的迭代此数for itr in range(resume_itr, FLAGS.pretrain_iterations + FLAGS.metatrain_iterations): feed_dict = {} # not for omniglot if 'generate' in dir(data_generator): batch_x, batch_y, amp, phase = data_generator.generate() if FLAGS.baseline == 'oracle': batch_x = np.concatenate([batch_x, np.zeros([batch_x.shape[0], batch_x.shape[1], 2])], 2) for i in range(FLAGS.meta_batch_size): batch_x[i, :, 1] = amp[i] batch_x[i, :, 2] = phase[i] """ # a: training data for inner gradient, # b: test data for meta gradient 这里 数据被分成两部分`inputa`和`inputb` `inputa`用来训练针对具体任务的模型,更新其权重 `inputb`用来测试基于`inputa`训练的模型,并计算对具体任务的模型在`intputb`的`losses` `inputb`上的测试`loss`用来更新元模型,具体实现见`maml.py`中`task_metalearn`函数 """ inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :] labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :] inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :] # b used for testing labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :] feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb} if itr < FLAGS.pretrain_iterations: # 前n步,预训练时只使用`loassa`更新元学习模型 input_tensors = [model.pretrain_op] else: input_tensors = [model.metatrain_op] ... result = sess.run(input_tensors, feed_dict)

在 MAML类construct_model函数中定义有task_metalearn函数,在这个函数中有使用num_updates参数,num_updates参数表示train函数中的每个元模型训练迭代中针对某个任务的模型迭代次数,针对某个任务的模型每更新一次,在测试数据inputb上计算1次losses,更新某个任务的模型num_updates次后,得到长度为num_updates的list lossesb,再用lossesb来更新元模型。

def task_metalearn(inp, reuse=True): """ Perform gradient descent for one task in the meta-batch. """ inputa, inputb, labela, labelb = inp task_outputbs, task_lossesb = [], [] if self.classification: task_accuraciesb = [] task_outputa = self.forward(inputa, weights, reuse=reuse) # only reuse on the first iter task_lossa = self.loss_func(task_outputa, labela) grads = tf.gradients(task_lossa, list(weights.values())) if FLAGS.stop_grad: grads = [tf.stop_gradient(grad) for grad in grads] gradients = dict(zip(weights.keys(), grads)) fast_weights = dict(zip(weights.keys(), [weights[key] - self.update_lr*gradients[key] for key in weights.keys()])) output = self.forward(inputb, fast_weights, reuse=True) task_outputbs.append(output) task_lossesb.append(self.loss_func(output, labelb)) for j in range(num_updates - 1): loss = self.loss_func(self.forward(inputa, fast_weights, reuse=True), labela) grads = tf.gradients(loss, list(fast_weights.values())) if FLAGS.stop_grad: grads = [tf.stop_gradient(grad) for grad in grads] gradients = dict(zip(fast_weights.keys(), grads)) fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.update_lr*gradients[key] for key in fast_weights.keys()])) output = self.forward(inputb, fast_weights, reuse=True) task_outputbs.append(output) task_lossesb.append(self.loss_func(output, labelb)) task_output = [task_outputa, task_outputbs, task_lossa, task_lossesb]

训练结束得到元模型后,要将元模型应用到具体任务时,要先根据提供的样本数据(x,y)对元模型进行微调test_num_updates后,再使用微调后的模型在测试数据上输出测试结果,其过程参照task_metalearn。这也就能解释测试时所用的类在训练时是没有的,为什么测试时模型可以输出测试的类别。正因为模型在测试时有个在少量测试数据上的微调的过程,可以理解成元学习模型先训练得到一个预训练权重,然后再在少量新的其他任务的训练数据上少里训练,然后在新任务的测试数据上验证。

类别为a,b的训练数据训练元学习模型微调fast_learning类别为c,d的测试数据<少量>类别为c,d的测试数据<大量>测试参考资料Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解一文入门元学习(Meta-Learning)(附代码)

欢迎访问个人网络日志🌹🌹知行空间🌹🌹

本文链接地址:https://www.jiuchutong.com/zhishi/300380.html 转载请保留说明!

上一篇:Python深度学习实战:人脸关键点(15点)检测pytorch实现

下一篇:目标检测:Faster-RCNN算法细节及代码解析(目标检测yolo算法)

  • 小米9号电动车(小米9)(小米9号电动车价格)

    小米9号电动车(小米9)(小米9号电动车价格)

  • 红米手表2屏幕尺寸(红米手表2屏幕多少钱)

    红米手表2屏幕尺寸(红米手表2屏幕多少钱)

  • 电脑蓝屏0x0000024(电脑蓝屏0x0000024怎么解决的)

    电脑蓝屏0x0000024(电脑蓝屏0x0000024怎么解决的)

  • 怎么以文件的形式发送视频(怎么以文件的形式发照片)

    怎么以文件的形式发送视频(怎么以文件的形式发照片)

  • 任天堂日版和国行区别

    任天堂日版和国行区别

  • cad曲线快捷键(cad 曲线快捷键)

    cad曲线快捷键(cad 曲线快捷键)

  • 微信朋友圈发字不发图片怎么发(微信朋友圈发字怎么发)

    微信朋友圈发字不发图片怎么发(微信朋友圈发字怎么发)

  • 51200mb是多少g流量(51200兆是多少g流量)

    51200mb是多少g流量(51200兆是多少g流量)

  • oppo实名认证能解除吗(oppo的实名认证)

    oppo实名认证能解除吗(oppo的实名认证)

  • 为什么有的qq没有匹配聊天(为什么有的QQ没有随心贴)

    为什么有的qq没有匹配聊天(为什么有的QQ没有随心贴)

  • 不是好友怎么举报(不是好友怎么举报朋友圈)

    不是好友怎么举报(不是好友怎么举报朋友圈)

  • 手机8p是什么手机(手机说的8p是什么意思)

    手机8p是什么手机(手机说的8p是什么意思)

  • eprom是指什么(eprom什么意思)

    eprom是指什么(eprom什么意思)

  • 拼多多卸载不掉(拼多多卸载不掉只能移除)

    拼多多卸载不掉(拼多多卸载不掉只能移除)

  • 声卡效果种类是什么(声卡效果都有哪些)

    声卡效果种类是什么(声卡效果都有哪些)

  • 手机qq桌面怎么设置(手机qq桌面怎么设置密码)

    手机qq桌面怎么设置(手机qq桌面怎么设置密码)

  • 抖音上传4分钟视频教程(抖音上传4分钟有收益吗)

    抖音上传4分钟视频教程(抖音上传4分钟有收益吗)

  • 小米9手机怎么截屏(小米9手机怎么截长图)

    小米9手机怎么截屏(小米9手机怎么截长图)

  • 华为手机拦截的电话在哪里(华为手机拦截的短信怎么能不显示)

    华为手机拦截的电话在哪里(华为手机拦截的短信怎么能不显示)

  • 支付宝账户就是手机号码吗(支付宝账户就是余额吗)

    支付宝账户就是手机号码吗(支付宝账户就是余额吗)

  • 打印机缩印怎么设置(打印机缩印怎么设置纸张方向)

    打印机缩印怎么设置(打印机缩印怎么设置纸张方向)

  • 怎样把照片改成500k(怎样把照片改成jpg格式)

    怎样把照片改成500k(怎样把照片改成jpg格式)

  • 小米手环可以显示具体信息吗(小米手环可以显示在微信运动上吗)

    小米手环可以显示具体信息吗(小米手环可以显示在微信运动上吗)

  • 荣耀play支持人脸识别吗(华为荣耀play人脸识别怎么设置)

    荣耀play支持人脸识别吗(华为荣耀play人脸识别怎么设置)

  • nfc开着安全吗(nfc开着有什么用)

    nfc开着安全吗(nfc开着有什么用)

  • 苹果7跟6外观有区别吗(苹果7跟6外观有啥区别)

    苹果7跟6外观有区别吗(苹果7跟6外观有啥区别)

  • 手机版爱奇艺怎么没有评论了(手机版爱奇艺怎么下载电影)

    手机版爱奇艺怎么没有评论了(手机版爱奇艺怎么下载电影)

  • 魅族16是不是16th(魅族16是不是超声波指纹)

    魅族16是不是16th(魅族16是不是超声波指纹)

  • 常用的前端大屏 适配方案(常用的前端大屏软件)

    常用的前端大屏 适配方案(常用的前端大屏软件)

  • 留抵税额是什么意思啊
  • 退个税app操作
  • 给职工租房的房租怎么进行账务处理?
  • 运输开票的税点是多少
  • 增值税普票可以开给个人吗
  • 公司为员工买零食
  • 有限公司能否申请破产
  • 会计科目已受控于应收应付系统
  • 银行存款利息收入要交增值税吗
  • 票面3个点什么意思
  • 企业注销未抵扣完的进项税
  • 已经认证的发票红冲发票需要收回原发票吗
  • 接受捐赠收入会计利润含税吗
  • 股权收购的好处
  • 进项结构明细表怎么做
  • 职工报销差旅费会计科目
  • 正确解读《非居民金融账户涉税信息尽职调查管理办法》
  • 小规模免增值税印花税用交吗
  • 公司购买二手车怎么抵税
  • 企业计提的工资薪金支出可以在税前扣除
  • 小规模未开票收入要交增值税吗
  • 积分兑换现金消费的会计分录
  • 什么是大头小头
  • 清算期间未申报债权
  • 支付工会经费
  • 安全系统不起作用或未正确安装 cad2016
  • 广告费和业务宣传费
  • linux怎么查找
  • win11怎么用win10界面
  • 事业单位专项资金包括哪些内容
  • php发送邮件代码
  • 债券的回购
  • 最高成本的手机是哪款
  • 简单了解航天员的生活
  • 美轮美奂的对象是什么
  • 享受所得税优惠情况说明
  • 增值税缓息是什么意思
  • auto learn
  • php 邮件发送
  • 老生常谈含义
  • 安装elipse教程
  • win10自带的重装能彻底清除上网记录和u盘记录吗
  • 命令行mkdir创建文件夹
  • 利息收入交所得税吗
  • 员工业余自学
  • 银行开出的承兑怎么兑现
  • wndgui降级
  • 业务预算包括直销费用吗
  • python字符串如何换行
  • 一般纳税人招待费扣除标准
  • 货物或应税劳务名称怎么填
  • 商品流通企业的进货费用
  • 应收账款损失率计算公式
  • 可供出售金融资产现在叫什么
  • 车辆抵押贷款怎么办理
  • 珠宝行业的会计处理方式
  • 固定资产损失税前扣除备查资料有哪些
  • 以前年度的固定资产入成原材料了怎么办
  • 税交多了可以退吗
  • 飞机票保险发票是什么样子的
  • 公司人才账户有什么用
  • 工程结算和工程竣工决算的区别
  • 供货商做产品配送怎么做
  • 纳税筹划有哪些特点以及原则?
  • 电脑windows怎么查
  • 电脑开机绿
  • linux修改环境变量后需要重启吗
  • windows8如何使用
  • windows7开机磁盘检查怎么取消
  • cocos2d开发的知名游戏
  • js填写input
  • android 自定义drawable
  • js实现物体移动
  • jquery移动div到另一个div中
  • 批处理怎么学
  • 简单阐述javascript的主要作用
  • jquery示例
  • 怎么打印历史发票
  • 甘肃税务局电子发票怎么开
  • 税务系统全面从严
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设