位置: IT常识 - 正文

超参数调优框架optuna(可配合pytorch)(超参数设置)

编辑:rootadmin
超参数调优框架optuna(可配合pytorch) 目录前言一、optuna的使用流程二、结果可视化三、pytorch代码使用optuna前言

推荐整理分享超参数调优框架optuna(可配合pytorch)(超参数设置),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:超参数设置,模型超参数调整,参数调优方法,超参数调优的作用,超参数选择,超参数调节,超参数优化算法,超参数优化,内容如对您有帮助,希望把文章链接给更多的朋友!

在深度学习快速发展的今天,对于不同深度学习模型的超参数优化(hyperparameter optimization),始终是一个比较头痛的问题。在超参较少的情况下,grid search是比较常见的方式,但是随着超参数量的不断增多,特别是对于神经网络而言,训练过程的超参和NN本身的超参组成的参数空间是巨大的,grid search方法会消耗巨大的资源,而且效果很差,因此寻找一个“机器炼丹”的框架十分必要。

optuna 是一个十分常用的超参数调优框架,具有操作简单,嵌入式强和动态调整参数空间等优点。另外还有其他框架也可以进行超参优化,如李沐老师提到的automl等。

一、optuna的使用流程

首先需要在命令行 pip install optuna 载入这个第三方库,载入之后import即可。

optuna中需要注意几个关键的名词: trail::一次实验 study::一次学习过程(包括多次实验)

import optunadef obj(trail):x = trail.suggest_float('x',1,5)return (x-3)*(x-3)stu = optuna.creat_study(study_name = 'test', direction = 'minimize')stu.optimize(obj, n_trials = 50)print(study.best_params)print(study.best_trial)print(study.best_trial.value)超参数调优框架optuna(可配合pytorch)(超参数设置)

该段实例代码中,函数obj定义一个含参数的需要优化的模块,带调整的超参数为 ‘x’ ,返回值为该模块的 objective value。超参x的类型为float,可调整空间为 [1,5] 左右都闭区间,常用的还有suggest_int表示整型,suggest_categorical表示字符串集合。

trail.suggest_int('name', 10, 50)trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh'])

study表示一个学习过程,direction参数为“minimize”表示对函数obj 的返回值(同时也是每次trial的objective value)向最小的方向优化。

二、结果可视化

optuna.visualization中包含了丰富的可视化工具。比较推荐使用的是以下三个:

optuna.visualization.plot_param_importances(stu).show()optuna.visualization.plot_optimization_history(stu).show()optuna.visualization.plot_slice(stu).show()

plot_param_importances 展示各个超参数对结果影响的重要性

plot_optimization_history 展示在n_trail 个trail中每次的objective value和当前的最优解

plot_slice 展示每个超参数在所有trail中取值的分布,以散点图的形式

三、pytorch代码使用optuna

在pytorch构建的MLP中进行使用,可以看到该调参框架是十分灵活的,可以设置训练参数,如batchsize,learning rate,也可也设置NN的参数,如隐藏层数目,激活函数类型等。

import torchfrom torch import nn, optimfrom torch.utils.data import DataLoaderfrom torch.autograd import Variable # 获取变量import optunadef train(batch_size, learning_rate, lossfunc, opt, hidden_layer, activefunc, weightdk,momentum): # 选出一些超参数 trainset_num = 800 testset_num = 50 train_dataset = myDataset(trainset_num) test_dataset = myDataset(testset_num) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) # 创建CNN模型, 并设置损失函数及优化器 model = MLP(hidden_layer, activefunc).cuda() # print(model) if lossfunc == 'MSE': criterion = nn.MSELoss().cuda() elif lossfunc == 'MAE': criterion = nn.L1Loss() # optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weightdk) optimizer =optim.RMSprop(model.parameters(),lr=learning_rate,weight_decay=weightdk, momentum=momentum) # 训练过程 for epoch in range(num_epoches): # 训练模式 model.train() for i, data in enumerate(train_loader): inputs, labels, _ = data inputs = Variable(inputs).float().cuda() labels = Variable(labels).float().cuda() # 前向传播 out = model(inputs) # 可以考虑加正则项 train_loss = criterion(out, labels) optimizer.zero_grad() train_loss.backward() optimizer.step() model.eval() testloss = test() #返回测试集合上的MAE print('Test MAE = ', resloss) return reslossdef objective(trail): batchsize = trail.suggest_int('batchsize', 1, 16) lr = trail.suggest_float('lr', 1e-4, 1e-2,step=0.0001) lossfunc = trail.suggest_categorical('loss', ['MSE', 'MAE']) opt = trail.suggest_categorical('opt', ['Adam', 'SGD']) hidden_layer = trail.suggest_int('hiddenlayer', 20, 1200) activefunc = trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh']) weightdekey = trail.suggest_float('weight_dekay', 0, 1,step=0.01) momentum= trail.suggest_float('momentum',0,1,step=0.01) loss = train(batchsize, lr, lossfunc, opt, hidden_layer, activefunc, weightdekey,momentum) return lossif __name__ == '__main__': st=time.time() study = optuna.create_study(study_name='test', direction='minimize') study.optimize(objective, n_trials=500) print(study.best_params) print(study.best_trial) print(study.best_trial.value) print(time.time()-st) optuna.visualization.plot_param_importances(study).show() optuna.visualization.plot_optimization_history(study).show() optuna.visualization.plot_slice(study).show()
本文链接地址:https://www.jiuchutong.com/zhishi/298419.html 转载请保留说明!

上一篇:Transformer 中的mask(transformer add norm)

下一篇:SSD训练数据集流程(学习记录)(ssd训练自己的数据集pytorch)

  • 那些家喻户晓的品牌为何还要拼命投广告?(那些家喻户晓的孩子)

    那些家喻户晓的品牌为何还要拼命投广告?(那些家喻户晓的孩子)

  • 苹果平板怎么开双屏模式(苹果平板怎么开不了机)

    苹果平板怎么开双屏模式(苹果平板怎么开不了机)

  • 抖音直播点亮是什么意思呢(抖音直播点亮是怎么回事)

    抖音直播点亮是什么意思呢(抖音直播点亮是怎么回事)

  • 华为手机微信怎么分身(华为手机微信怎么设置密码锁)

    华为手机微信怎么分身(华为手机微信怎么设置密码锁)

  • 对方加你微信过期了怎么办(对方加微信过来还没认证不小心删了丛那里找回来)

    对方加你微信过期了怎么办(对方加微信过来还没认证不小心删了丛那里找回来)

  • 计算机病毒寄生方式(计算机病毒寄生方式分类)

    计算机病毒寄生方式(计算机病毒寄生方式分类)

  • macbook如何新建word文档(macbook如何新建桌面)

    macbook如何新建word文档(macbook如何新建桌面)

  • 拼多多新店刷流量应注意哪些细节?(拼多多刷销量流程)

    拼多多新店刷流量应注意哪些细节?(拼多多刷销量流程)

  • 支付宝余额宝冻结的金额怎么拿出来?(支付宝余额宝冻结资金是怎么回事?)

    支付宝余额宝冻结的金额怎么拿出来?(支付宝余额宝冻结资金是怎么回事?)

  • mate30后面四个孔分别是什么

    mate30后面四个孔分别是什么

  • 微信信息没有提示怎么回事(微信发了消息20分钟想撤回)

    微信信息没有提示怎么回事(微信发了消息20分钟想撤回)

  • 苹果手机发语音有杂音怎么回事(苹果手机发语音系统错误怎么办)

    苹果手机发语音有杂音怎么回事(苹果手机发语音系统错误怎么办)

  • vivo的语音助手叫什么(怎么打开vivo的语音助手)

    vivo的语音助手叫什么(怎么打开vivo的语音助手)

  • 为什么软件下载后不在手机上显示(为什么软件下载到d盘,c盘内存减少了呢)

    为什么软件下载后不在手机上显示(为什么软件下载到d盘,c盘内存减少了呢)

  • 多多果园怎么删除好友(多多果园怎么删除果树)

    多多果园怎么删除好友(多多果园怎么删除果树)

  • 华为m6可以用鼠标吗(华为m6电脑模式可以用鼠标吗)

    华为m6可以用鼠标吗(华为m6电脑模式可以用鼠标吗)

  • 荣耀怎么设置备忘录提醒(荣耀怎么设置备用密码)

    荣耀怎么设置备忘录提醒(荣耀怎么设置备用密码)

  • 手机丢失支付宝如何挂失(手机丢失支付宝怎么冻结账户)

    手机丢失支付宝如何挂失(手机丢失支付宝怎么冻结账户)

  • 华为售后检测要拆机吗(华为售后检测要带什么)

    华为售后检测要拆机吗(华为售后检测要带什么)

  • 京东消息提醒怎么取消(京东消息提醒怎么关掉)

    京东消息提醒怎么取消(京东消息提醒怎么关掉)

  • 全民k歌怎么玩(全民k歌怎么玩游戏)

    全民k歌怎么玩(全民k歌怎么玩游戏)

  • 快手怎么卡点(快手怎么卡点视频)

    快手怎么卡点(快手怎么卡点视频)

  • 销售货物收入确认条件
  • 电子税务局没有发票开具
  • 税额为零的增值税是多少
  • 应收账款计入借方贷方
  • 房地产预缴增值税是含税还是不含税
  • 固定资产累计折旧是什么科目
  • 自然人股权转让的纳税筹划
  • 房地产企业土地使用税
  • 增值税专票怎么抵扣
  • 政府补贴项目账务怎么做
  • 存货取得长期股权投资
  • 收银系统已入库怎么操作
  • 增资扩股如何操作
  • 事业单位发生管理费用
  • 政府扶持资金是什么意思
  • 一般纳税人转出进项税额
  • 小规模企业的企业所得税怎么交
  • 分公司哪些税需要交
  • 单位为员工缴纳社保分录
  • 反向吸收合并账务处理
  • 准予抵扣的进项税额有哪些
  • 年末计提银行借款利息
  • 房地产中介公司排名
  • 怎么计算研发费用占销售收入总额比例
  • 所得税汇算清缴补税的会计处理
  • 无法访问您可能没有权限使用资源
  • 职工教育经费是工资总额的多少
  • linux 判断语句
  • 招标场地费怎么收
  • 赡养老人支出如果有四个子女都要填吗
  • PHP:oci_commit()的用法_Oracle函数
  • 出纳借方
  • 广告公司的工程师好做吗
  • pavsrv50.exe - pavsrv50进程管理信息
  • 通过二手车买进套现
  • php的魔术函数
  • 资产减值会计处理论文
  • 企业合并发生的审计费用,评估费用会计分录
  • 转让旧固定资产怎么做账
  • 人力资源外包可以去吗
  • 房租违约金怎么开发票
  • redis两种持久化方式的优缺点
  • 基建账并账规定
  • vue 路由
  • php执行系统命令函数
  • 公司租赁个人车辆怎么开发票
  • 所得税弥补亏损年限10
  • 收到个人所得税汇算清缴短信
  • 贴现法付息的实际利息
  • 稳岗返还资金最新账务处理
  • 个体工商户公帐转法人私人账户
  • 汽修修理厂
  • 印花税缴款了发票怎么查
  • 业务招待费扣除标准是多少
  • 代开发票取得的收入如何入账?
  • 研发费用账务处理实例
  • 为什么规定视同销售?
  • 开发软件应采用
  • 企业所得税预缴2‰
  • sqlmd5加密后解密
  • win2003 安装iis
  • windows modules installer占用磁盘高
  • win8还能用吗
  • window10系统邮件设置在哪里
  • 在windows操作中
  • 如何关掉数据
  • elccest.exe是间谍广告程序吗 elccest进程有什么作用
  • cortanawin10在哪
  • ubuntu 无法正常启动
  • Linux一键安装ftp
  • centos6启动服务的命令
  • windows10version20h2的03
  • win7系统引导坏了怎么修复
  • windows10一分钟重启解决
  • python函数详解
  • 从《AndEngine游戏开发实践指南》开始,学习AndEngine引擎
  • Android自定义系统服务框架
  • 安庆税务局窗口电话
  • 印花税在哪里查询
  • 税务机关垂直领导
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设