位置: IT常识 - 正文

超参数调优框架optuna(可配合pytorch)(超参数设置)

编辑:rootadmin
超参数调优框架optuna(可配合pytorch) 目录前言一、optuna的使用流程二、结果可视化三、pytorch代码使用optuna前言

推荐整理分享超参数调优框架optuna(可配合pytorch)(超参数设置),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:超参数设置,模型超参数调整,参数调优方法,超参数调优的作用,超参数选择,超参数调节,超参数优化算法,超参数优化,内容如对您有帮助,希望把文章链接给更多的朋友!

在深度学习快速发展的今天,对于不同深度学习模型的超参数优化(hyperparameter optimization),始终是一个比较头痛的问题。在超参较少的情况下,grid search是比较常见的方式,但是随着超参数量的不断增多,特别是对于神经网络而言,训练过程的超参和NN本身的超参组成的参数空间是巨大的,grid search方法会消耗巨大的资源,而且效果很差,因此寻找一个“机器炼丹”的框架十分必要。

optuna 是一个十分常用的超参数调优框架,具有操作简单,嵌入式强和动态调整参数空间等优点。另外还有其他框架也可以进行超参优化,如李沐老师提到的automl等。

一、optuna的使用流程

首先需要在命令行 pip install optuna 载入这个第三方库,载入之后import即可。

optuna中需要注意几个关键的名词: trail::一次实验 study::一次学习过程(包括多次实验)

import optunadef obj(trail):x = trail.suggest_float('x',1,5)return (x-3)*(x-3)stu = optuna.creat_study(study_name = 'test', direction = 'minimize')stu.optimize(obj, n_trials = 50)print(study.best_params)print(study.best_trial)print(study.best_trial.value)超参数调优框架optuna(可配合pytorch)(超参数设置)

该段实例代码中,函数obj定义一个含参数的需要优化的模块,带调整的超参数为 ‘x’ ,返回值为该模块的 objective value。超参x的类型为float,可调整空间为 [1,5] 左右都闭区间,常用的还有suggest_int表示整型,suggest_categorical表示字符串集合。

trail.suggest_int('name', 10, 50)trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh'])

study表示一个学习过程,direction参数为“minimize”表示对函数obj 的返回值(同时也是每次trial的objective value)向最小的方向优化。

二、结果可视化

optuna.visualization中包含了丰富的可视化工具。比较推荐使用的是以下三个:

optuna.visualization.plot_param_importances(stu).show()optuna.visualization.plot_optimization_history(stu).show()optuna.visualization.plot_slice(stu).show()

plot_param_importances 展示各个超参数对结果影响的重要性

plot_optimization_history 展示在n_trail 个trail中每次的objective value和当前的最优解

plot_slice 展示每个超参数在所有trail中取值的分布,以散点图的形式

三、pytorch代码使用optuna

在pytorch构建的MLP中进行使用,可以看到该调参框架是十分灵活的,可以设置训练参数,如batchsize,learning rate,也可也设置NN的参数,如隐藏层数目,激活函数类型等。

import torchfrom torch import nn, optimfrom torch.utils.data import DataLoaderfrom torch.autograd import Variable # 获取变量import optunadef train(batch_size, learning_rate, lossfunc, opt, hidden_layer, activefunc, weightdk,momentum): # 选出一些超参数 trainset_num = 800 testset_num = 50 train_dataset = myDataset(trainset_num) test_dataset = myDataset(testset_num) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) # 创建CNN模型, 并设置损失函数及优化器 model = MLP(hidden_layer, activefunc).cuda() # print(model) if lossfunc == 'MSE': criterion = nn.MSELoss().cuda() elif lossfunc == 'MAE': criterion = nn.L1Loss() # optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weightdk) optimizer =optim.RMSprop(model.parameters(),lr=learning_rate,weight_decay=weightdk, momentum=momentum) # 训练过程 for epoch in range(num_epoches): # 训练模式 model.train() for i, data in enumerate(train_loader): inputs, labels, _ = data inputs = Variable(inputs).float().cuda() labels = Variable(labels).float().cuda() # 前向传播 out = model(inputs) # 可以考虑加正则项 train_loss = criterion(out, labels) optimizer.zero_grad() train_loss.backward() optimizer.step() model.eval() testloss = test() #返回测试集合上的MAE print('Test MAE = ', resloss) return reslossdef objective(trail): batchsize = trail.suggest_int('batchsize', 1, 16) lr = trail.suggest_float('lr', 1e-4, 1e-2,step=0.0001) lossfunc = trail.suggest_categorical('loss', ['MSE', 'MAE']) opt = trail.suggest_categorical('opt', ['Adam', 'SGD']) hidden_layer = trail.suggest_int('hiddenlayer', 20, 1200) activefunc = trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh']) weightdekey = trail.suggest_float('weight_dekay', 0, 1,step=0.01) momentum= trail.suggest_float('momentum',0,1,step=0.01) loss = train(batchsize, lr, lossfunc, opt, hidden_layer, activefunc, weightdekey,momentum) return lossif __name__ == '__main__': st=time.time() study = optuna.create_study(study_name='test', direction='minimize') study.optimize(objective, n_trials=500) print(study.best_params) print(study.best_trial) print(study.best_trial.value) print(time.time()-st) optuna.visualization.plot_param_importances(study).show() optuna.visualization.plot_optimization_history(study).show() optuna.visualization.plot_slice(study).show()
本文链接地址:https://www.jiuchutong.com/zhishi/298419.html 转载请保留说明!

上一篇:Transformer 中的mask(transformer add norm)

下一篇:SSD训练数据集流程(学习记录)(ssd训练自己的数据集pytorch)

  • 微信营销究竟应该怎么做?(微信营销主要有哪些)

    微信营销究竟应该怎么做?(微信营销主要有哪些)

  • 红米k40游戏增强版充电多久充满(红米k40游戏增强版最严重缺点)

    红米k40游戏增强版充电多久充满(红米k40游戏增强版最严重缺点)

  • 桌面在c盘什么位置(桌面在c盘什么位置win10)

    桌面在c盘什么位置(桌面在c盘什么位置win10)

  • 佳能m50怎么关闭红色闪灯(佳能m50怎么关闭对焦辅助灯)

    佳能m50怎么关闭红色闪灯(佳能m50怎么关闭对焦辅助灯)

  • 嗨来电扣不扣费(嗨来电收费吗收多少钱)

    嗨来电扣不扣费(嗨来电收费吗收多少钱)

  • ai放大缩小的快捷键(ai放大缩小的快捷键 不能用)

    ai放大缩小的快捷键(ai放大缩小的快捷键 不能用)

  • wan口服务器无响应(wan口服务器无响应怎么回事)

    wan口服务器无响应(wan口服务器无响应怎么回事)

  • 大众点评可以看到访客记录吗(大众点评可以看营业执照吗)

    大众点评可以看到访客记录吗(大众点评可以看营业执照吗)

  • 微信删除该聊天是什么意思(微信删除该聊天还能恢复吗)

    微信删除该聊天是什么意思(微信删除该聊天还能恢复吗)

  • 误删qq漫游记录怎么恢复(误删qq漫游记录怎么恢复聊天记录)

    误删qq漫游记录怎么恢复(误删qq漫游记录怎么恢复聊天记录)

  • 为什么电脑qq发不了文件(为什么电脑qq发的消息手机看不见)

    为什么电脑qq发不了文件(为什么电脑qq发的消息手机看不见)

  • photoshop是一种什么软件(photoshop属于什么软件?)

    photoshop是一种什么软件(photoshop属于什么软件?)

  • 淘宝可以申请几次换货(淘宝可以申请几次价保)

    淘宝可以申请几次换货(淘宝可以申请几次价保)

  • wps删除空白列删不掉(wps删除空白行快捷键)

    wps删除空白列删不掉(wps删除空白行快捷键)

  • 钉钉打卡怎么退出公司(钉钉打卡怎么退出全员群)

    钉钉打卡怎么退出公司(钉钉打卡怎么退出全员群)

  • 苹果11怎么显示电量百分比(苹果11怎么显示网速)

    苹果11怎么显示电量百分比(苹果11怎么显示网速)

  • iphone8plus重量多少克(苹果8plus的重量)

    iphone8plus重量多少克(苹果8plus的重量)

  • 加密dns关闭有什么影响(加密dns关闭有什么好处)

    加密dns关闭有什么影响(加密dns关闭有什么好处)

  • 为什么发快手别人看不见(为什么发快手别人在关注页看不到)

    为什么发快手别人看不见(为什么发快手别人在关注页看不到)

  • 苹果手机显示3g怎么办(苹果手机显示3g怎么调成4g)

    苹果手机显示3g怎么办(苹果手机显示3g怎么调成4g)

  • Win11怎么禁用网络连接?Win11禁用网络连接方法(windows怎么禁用网络)

    Win11怎么禁用网络连接?Win11禁用网络连接方法(windows怎么禁用网络)

  • 无痛人流多少钱(无痛人流多少钱?)

    无痛人流多少钱(无痛人流多少钱?)

  • JavaScript下部分--头歌(educoder)实训作业题目及答案(javascript局部变量)

    JavaScript下部分--头歌(educoder)实训作业题目及答案(javascript局部变量)

  • 注册资本没有到位可以注销吗
  • 不缴或少缴应纳税款的处罚措施
  • 公司购买设备报告怎么写
  • 专票不抵扣认证什么意思
  • 增值税留抵税额是什么意思
  • 发票没有纳税人识别号能开吗
  • 小规模开普票多少税点
  • 物业公司收款一般多久
  • 医药零售行业 利润构成
  • 金融企业往来收入科目属于什么科
  • 本年利润,利润分配
  • 短期借款在房地产怎么算
  • 年底给职工发啥实物
  • 出口退税的城建税和教育费附加怎么算
  • 增值税的免征增值税范围
  • 未达起征点企业怎么处理
  • 长期待摊费忘记摊了怎么办
  • 发票密码区出来了一点
  • 汇算清缴之前找回来成本发票可以吗
  • 工资年终奖金扣多少税
  • 公账转公账没有发票
  • prevsrv.exe - prevsrv是什么进程 有什么用
  • 分配水电费会计分录怎么写
  • 收到别的公司对公转账往来
  • 苹果电脑记笔记
  • 电脑进不了系统怎么用u盘重装
  • 收到法院的案件款应该怎么做帐
  • 购置资产是什么财务活动
  • 为员工租赁房屋产生的租赁费可以抵扣增值税和所得税吗
  • atikdag.sys
  • 为什么linux这么受欢迎
  • file php
  • 企业合并分立
  • 灰狼算法的改进
  • 项目辅材计入什么科目
  • vue2转vue3工具
  • el-upload上传文件必传校验
  • laravel 实例
  • 用友u8反结账反记账的操作步骤
  • 应交增值税缴纳后入什么费用
  • mysql的命名规则
  • 不是公司员工差旅费可以入差旅费吗
  • 什么是指企业的市场营销活动发生影响的各种因素的总和
  • 期货风险准备金计提比例
  • 无形资产的意思是
  • 事业单位小规模纳税人增值税账务处理
  • 异地车辆登记证书怎么补办
  • 商品购进核算
  • 未开票的收入如何确认分录
  • 公司买的吃的计入什么科目
  • 企业筹建期间银行开户要求
  • 贷款服务的利息怎么算
  • 餐饮业可以开具免税发票吗
  • 什么是企业管理的基础工作
  • 会计刚开始学什么
  • 汇兑损益金额是怎么算出来的
  • OBJECTPROPERTY与sp_rename更改对象名称的介绍
  • Windows下MySQL 5.7无法启动的解决方法
  • windows8.1开机
  • win102009发布日期
  • 提升英语
  • “explorer.exe”进程文件
  • windows系统同时按下CTRL+ALT+DEL键没有弹出任务管理器的解决方法
  • win8系统怎么关掉开机密码
  • linux系统中make的用法
  • win10电脑提示
  • [置顶]电影名字《收件人不详》
  • Android setVisibility的总结~
  • 第一次接触怎么形容
  • nodejs require 路径查找
  • unity视频播放
  • unity3d赛车游戏毕业设计
  • jquery通过属性值获取元素
  • cocos2d-x安装
  • ca证书密码是什么
  • 税务稽查预警指标
  • 货物劳务税包括哪些税
  • 免抵退税办法不得抵扣的进项
  • 财政法和经济法的关系
  • 如何做好巡察组组员
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设