位置: IT常识 - 正文

超参数调优框架optuna(可配合pytorch)(超参数设置)

编辑:rootadmin
超参数调优框架optuna(可配合pytorch) 目录前言一、optuna的使用流程二、结果可视化三、pytorch代码使用optuna前言

推荐整理分享超参数调优框架optuna(可配合pytorch)(超参数设置),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:超参数设置,模型超参数调整,参数调优方法,超参数调优的作用,超参数选择,超参数调节,超参数优化算法,超参数优化,内容如对您有帮助,希望把文章链接给更多的朋友!

在深度学习快速发展的今天,对于不同深度学习模型的超参数优化(hyperparameter optimization),始终是一个比较头痛的问题。在超参较少的情况下,grid search是比较常见的方式,但是随着超参数量的不断增多,特别是对于神经网络而言,训练过程的超参和NN本身的超参组成的参数空间是巨大的,grid search方法会消耗巨大的资源,而且效果很差,因此寻找一个“机器炼丹”的框架十分必要。

optuna 是一个十分常用的超参数调优框架,具有操作简单,嵌入式强和动态调整参数空间等优点。另外还有其他框架也可以进行超参优化,如李沐老师提到的automl等。

一、optuna的使用流程

首先需要在命令行 pip install optuna 载入这个第三方库,载入之后import即可。

optuna中需要注意几个关键的名词: trail::一次实验 study::一次学习过程(包括多次实验)

import optunadef obj(trail):x = trail.suggest_float('x',1,5)return (x-3)*(x-3)stu = optuna.creat_study(study_name = 'test', direction = 'minimize')stu.optimize(obj, n_trials = 50)print(study.best_params)print(study.best_trial)print(study.best_trial.value)超参数调优框架optuna(可配合pytorch)(超参数设置)

该段实例代码中,函数obj定义一个含参数的需要优化的模块,带调整的超参数为 ‘x’ ,返回值为该模块的 objective value。超参x的类型为float,可调整空间为 [1,5] 左右都闭区间,常用的还有suggest_int表示整型,suggest_categorical表示字符串集合。

trail.suggest_int('name', 10, 50)trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh'])

study表示一个学习过程,direction参数为“minimize”表示对函数obj 的返回值(同时也是每次trial的objective value)向最小的方向优化。

二、结果可视化

optuna.visualization中包含了丰富的可视化工具。比较推荐使用的是以下三个:

optuna.visualization.plot_param_importances(stu).show()optuna.visualization.plot_optimization_history(stu).show()optuna.visualization.plot_slice(stu).show()

plot_param_importances 展示各个超参数对结果影响的重要性

plot_optimization_history 展示在n_trail 个trail中每次的objective value和当前的最优解

plot_slice 展示每个超参数在所有trail中取值的分布,以散点图的形式

三、pytorch代码使用optuna

在pytorch构建的MLP中进行使用,可以看到该调参框架是十分灵活的,可以设置训练参数,如batchsize,learning rate,也可也设置NN的参数,如隐藏层数目,激活函数类型等。

import torchfrom torch import nn, optimfrom torch.utils.data import DataLoaderfrom torch.autograd import Variable # 获取变量import optunadef train(batch_size, learning_rate, lossfunc, opt, hidden_layer, activefunc, weightdk,momentum): # 选出一些超参数 trainset_num = 800 testset_num = 50 train_dataset = myDataset(trainset_num) test_dataset = myDataset(testset_num) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) # 创建CNN模型, 并设置损失函数及优化器 model = MLP(hidden_layer, activefunc).cuda() # print(model) if lossfunc == 'MSE': criterion = nn.MSELoss().cuda() elif lossfunc == 'MAE': criterion = nn.L1Loss() # optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weightdk) optimizer =optim.RMSprop(model.parameters(),lr=learning_rate,weight_decay=weightdk, momentum=momentum) # 训练过程 for epoch in range(num_epoches): # 训练模式 model.train() for i, data in enumerate(train_loader): inputs, labels, _ = data inputs = Variable(inputs).float().cuda() labels = Variable(labels).float().cuda() # 前向传播 out = model(inputs) # 可以考虑加正则项 train_loss = criterion(out, labels) optimizer.zero_grad() train_loss.backward() optimizer.step() model.eval() testloss = test() #返回测试集合上的MAE print('Test MAE = ', resloss) return reslossdef objective(trail): batchsize = trail.suggest_int('batchsize', 1, 16) lr = trail.suggest_float('lr', 1e-4, 1e-2,step=0.0001) lossfunc = trail.suggest_categorical('loss', ['MSE', 'MAE']) opt = trail.suggest_categorical('opt', ['Adam', 'SGD']) hidden_layer = trail.suggest_int('hiddenlayer', 20, 1200) activefunc = trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh']) weightdekey = trail.suggest_float('weight_dekay', 0, 1,step=0.01) momentum= trail.suggest_float('momentum',0,1,step=0.01) loss = train(batchsize, lr, lossfunc, opt, hidden_layer, activefunc, weightdekey,momentum) return lossif __name__ == '__main__': st=time.time() study = optuna.create_study(study_name='test', direction='minimize') study.optimize(objective, n_trials=500) print(study.best_params) print(study.best_trial) print(study.best_trial.value) print(time.time()-st) optuna.visualization.plot_param_importances(study).show() optuna.visualization.plot_optimization_history(study).show() optuna.visualization.plot_slice(study).show()
本文链接地址:https://www.jiuchutong.com/zhishi/298419.html 转载请保留说明!

上一篇:Transformer 中的mask(transformer add norm)

下一篇:SSD训练数据集流程(学习记录)(ssd训练自己的数据集pytorch)

  • 三星电子锁怎么改密码(三星电子锁怎么看有没有电)

    三星电子锁怎么改密码(三星电子锁怎么看有没有电)

  • 华为荣耀8可以用手机u盘吗(华为荣耀8可以升级鸿蒙系统吗)

    华为荣耀8可以用手机u盘吗(华为荣耀8可以升级鸿蒙系统吗)

  • 一个手机能注册几个微信号(一个手机号可以注册几个淘宝号)

    一个手机能注册几个微信号(一个手机号可以注册几个淘宝号)

  • 苹果18w是几v几a(iphone 18w)

    苹果18w是几v几a(iphone 18w)

  • 快手b类行为都有哪些(快手b类违规是怎么产生的?有人举报吗?)

    快手b类行为都有哪些(快手b类违规是怎么产生的?有人举报吗?)

  • 微信投票过于频繁 请稍后重试(微信投票过于频繁会怎样)

    微信投票过于频繁 请稍后重试(微信投票过于频繁会怎样)

  • macpro截图快捷键(macbookpro截图快捷键 去哪了)

    macpro截图快捷键(macbookpro截图快捷键 去哪了)

  • 电话传播声音的原理(电话传声的原理)

    电话传播声音的原理(电话传声的原理)

  • 抖音直播为什么没人进来(抖音直播为什么小孩不能出现在画面里)

    抖音直播为什么没人进来(抖音直播为什么小孩不能出现在画面里)

  • version是什么版本(version版本)

    version是什么版本(version版本)

  • 苹果大陆版和美版有什么区别(苹果大陆版和美国版区别)

    苹果大陆版和美版有什么区别(苹果大陆版和美国版区别)

  • 锁屏声音怎么调大(锁屏声音怎么调节)

    锁屏声音怎么调大(锁屏声音怎么调节)

  • 来电anonymous怎么解决(来电anonymous在哪里设置)

    来电anonymous怎么解决(来电anonymous在哪里设置)

  • 剪映里的蒙版怎么用(剪映里的蒙版怎么弄)

    剪映里的蒙版怎么用(剪映里的蒙版怎么弄)

  • 华为怎么共享热点网络连接(华为手机共享热点在哪里设置)

    华为怎么共享热点网络连接(华为手机共享热点在哪里设置)

  • 抖音红包提现不了怎么回事(抖音红包提现不出来)

    抖音红包提现不了怎么回事(抖音红包提现不出来)

  • 微信人脸验证老是失败(为什么微信人脸验证一直通不过)

    微信人脸验证老是失败(为什么微信人脸验证一直通不过)

  • etc插卡显示无卡怎么办(ETC插卡显示无卡后黑屏)

    etc插卡显示无卡怎么办(ETC插卡显示无卡后黑屏)

  • fttb和ftth的不同点(ftth和fttp)

    fttb和ftth的不同点(ftth和fttp)

  • iphone7plus无服务解决办法(iphone7plus无服务维修需要多少钱)

    iphone7plus无服务解决办法(iphone7plus无服务维修需要多少钱)

  • 手机里的歌怎么传到另外的内存卡(手机里的歌怎么导入mp3)

    手机里的歌怎么传到另外的内存卡(手机里的歌怎么导入mp3)

  • 手机怎么关闭降噪功能(手机怎么关闭降噪)

    手机怎么关闭降噪功能(手机怎么关闭降噪)

  • 华为mate30有红外功能吗(华为mate30有红外线遥控吗)

    华为mate30有红外功能吗(华为mate30有红外线遥控吗)

  • soul怎么搜索别人的id(新版soul怎么搜索别人的id)

    soul怎么搜索别人的id(新版soul怎么搜索别人的id)

  • apple watch2和3的区别(apple watch 2和3有什么区别)

    apple watch2和3的区别(apple watch 2和3有什么区别)

  • 小米9se红外遥控怎么用(小米9se红外遥控怎么设置)

    小米9se红外遥控怎么用(小米9se红外遥控怎么设置)

  • 小米蓝牙耳机air怎么配对(小米蓝牙耳机air2 se连接不上)

    小米蓝牙耳机air怎么配对(小米蓝牙耳机air2 se连接不上)

  • 初级职称经济法目录
  • 金税四期什么时候全国运行
  • 企业捐赠灾区
  • 减免企业所得税怎么算
  • 小规模纳税人增值税起征点
  • 工业企业的三个阶段
  • 现金流量表年报期末现金余额
  • 供应商转让合同
  • 暂扣员工工资怎么做账
  • 转让专利技术使用权属于什么收入
  • 增值税普通发票需要交税吗
  • 没有缴纳契税
  • 纳税的税种有哪些
  • 航天服务费530是什么?
  • 实收资本印花税最新规定
  • 增值税减免税申报明细表免税代码和名称
  • 在建工程计提减值准备可以转回吗
  • 为什么红字信息查不到
  • 一般纳税人何种情况不需要交附加税
  • 暂估在建工程会计科目
  • 银行回单箱费会扣吗
  • 公司在银行购买金币没有发票
  • 工资一直计提但是未发有影响吗
  • 设备的折旧率是什么意思
  • 超市发购物卡给员工会计分录
  • macos catalina与macos big区别
  • window10进程
  • PHPfor循环语句10的阶乘
  • 罚款是否需要开发票
  • 怎么登记总分类账簿
  • 财务费用为什么增加
  • 利息支出属于生产成本吗
  • 蚊子叮咬怎么办手抄报
  • 真三国在哪下载
  • vue引入echarts柱状图
  • 工业企业库存商品的初始入账成本有
  • php如何生成html
  • 进货成本价是什么
  • eclipse php wamp配置教程
  • nerf 入门
  • 图像自动生成
  • js浅拷贝和深拷贝的方法
  • 训练集验证集和测试集
  • 微信小程序解锁安全吗
  • 微信多开使用方法
  • smitty命令用法
  • 一个残疾证一年单位免多少税2023
  • 预付卡销售可以报销吗
  • java 通配符
  • 管家婆付款单凭证科目如何修改
  • 第二季度所得税可以弥补以前年度亏损吗
  • phpcms二次开发教程
  • 进项税额漏报处理办法
  • 税务怎么认定虚列工资
  • 建筑行业小规模纳税人和一般纳税人
  • 主办会计与往来会计区别
  • 什么时候开始取卵
  • 税审报告需要什么资料
  • 本月没有认证的发票怎么做账
  • 事业单位装修费账务处理
  • 企业代理社保
  • mysql数据库主从数据不一致
  • linux命令实现
  • mysql 全量备份
  • win8桌面图标不显示
  • iis怎么打开项目
  • 安装solaris11
  • 虚拟机linux使用
  • linux安装编译工具
  • windows8怎么设置锁屏密码
  • 文件夹删不掉显示另一个程序打开
  • win7如何变快
  • javascript中的document.write
  • css网站布局实录 pdf
  • java 视频教程
  • 浙江税务网上开票流程图
  • 临沂学生医疗保险多少钱
  • 自然人扣缴端怎么申报个税
  • 请问报考国家税务局难吗
  • 如何查询有没有交医保费用
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设