位置: IT常识 - 正文

超参数调优框架optuna(可配合pytorch)(超参数设置)

编辑:rootadmin
超参数调优框架optuna(可配合pytorch) 目录前言一、optuna的使用流程二、结果可视化三、pytorch代码使用optuna前言

推荐整理分享超参数调优框架optuna(可配合pytorch)(超参数设置),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:超参数设置,模型超参数调整,参数调优方法,超参数调优的作用,超参数选择,超参数调节,超参数优化算法,超参数优化,内容如对您有帮助,希望把文章链接给更多的朋友!

在深度学习快速发展的今天,对于不同深度学习模型的超参数优化(hyperparameter optimization),始终是一个比较头痛的问题。在超参较少的情况下,grid search是比较常见的方式,但是随着超参数量的不断增多,特别是对于神经网络而言,训练过程的超参和NN本身的超参组成的参数空间是巨大的,grid search方法会消耗巨大的资源,而且效果很差,因此寻找一个“机器炼丹”的框架十分必要。

optuna 是一个十分常用的超参数调优框架,具有操作简单,嵌入式强和动态调整参数空间等优点。另外还有其他框架也可以进行超参优化,如李沐老师提到的automl等。

一、optuna的使用流程

首先需要在命令行 pip install optuna 载入这个第三方库,载入之后import即可。

optuna中需要注意几个关键的名词: trail::一次实验 study::一次学习过程(包括多次实验)

import optunadef obj(trail):x = trail.suggest_float('x',1,5)return (x-3)*(x-3)stu = optuna.creat_study(study_name = 'test', direction = 'minimize')stu.optimize(obj, n_trials = 50)print(study.best_params)print(study.best_trial)print(study.best_trial.value)超参数调优框架optuna(可配合pytorch)(超参数设置)

该段实例代码中,函数obj定义一个含参数的需要优化的模块,带调整的超参数为 ‘x’ ,返回值为该模块的 objective value。超参x的类型为float,可调整空间为 [1,5] 左右都闭区间,常用的还有suggest_int表示整型,suggest_categorical表示字符串集合。

trail.suggest_int('name', 10, 50)trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh'])

study表示一个学习过程,direction参数为“minimize”表示对函数obj 的返回值(同时也是每次trial的objective value)向最小的方向优化。

二、结果可视化

optuna.visualization中包含了丰富的可视化工具。比较推荐使用的是以下三个:

optuna.visualization.plot_param_importances(stu).show()optuna.visualization.plot_optimization_history(stu).show()optuna.visualization.plot_slice(stu).show()

plot_param_importances 展示各个超参数对结果影响的重要性

plot_optimization_history 展示在n_trail 个trail中每次的objective value和当前的最优解

plot_slice 展示每个超参数在所有trail中取值的分布,以散点图的形式

三、pytorch代码使用optuna

在pytorch构建的MLP中进行使用,可以看到该调参框架是十分灵活的,可以设置训练参数,如batchsize,learning rate,也可也设置NN的参数,如隐藏层数目,激活函数类型等。

import torchfrom torch import nn, optimfrom torch.utils.data import DataLoaderfrom torch.autograd import Variable # 获取变量import optunadef train(batch_size, learning_rate, lossfunc, opt, hidden_layer, activefunc, weightdk,momentum): # 选出一些超参数 trainset_num = 800 testset_num = 50 train_dataset = myDataset(trainset_num) test_dataset = myDataset(testset_num) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) # 创建CNN模型, 并设置损失函数及优化器 model = MLP(hidden_layer, activefunc).cuda() # print(model) if lossfunc == 'MSE': criterion = nn.MSELoss().cuda() elif lossfunc == 'MAE': criterion = nn.L1Loss() # optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weightdk) optimizer =optim.RMSprop(model.parameters(),lr=learning_rate,weight_decay=weightdk, momentum=momentum) # 训练过程 for epoch in range(num_epoches): # 训练模式 model.train() for i, data in enumerate(train_loader): inputs, labels, _ = data inputs = Variable(inputs).float().cuda() labels = Variable(labels).float().cuda() # 前向传播 out = model(inputs) # 可以考虑加正则项 train_loss = criterion(out, labels) optimizer.zero_grad() train_loss.backward() optimizer.step() model.eval() testloss = test() #返回测试集合上的MAE print('Test MAE = ', resloss) return reslossdef objective(trail): batchsize = trail.suggest_int('batchsize', 1, 16) lr = trail.suggest_float('lr', 1e-4, 1e-2,step=0.0001) lossfunc = trail.suggest_categorical('loss', ['MSE', 'MAE']) opt = trail.suggest_categorical('opt', ['Adam', 'SGD']) hidden_layer = trail.suggest_int('hiddenlayer', 20, 1200) activefunc = trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh']) weightdekey = trail.suggest_float('weight_dekay', 0, 1,step=0.01) momentum= trail.suggest_float('momentum',0,1,step=0.01) loss = train(batchsize, lr, lossfunc, opt, hidden_layer, activefunc, weightdekey,momentum) return lossif __name__ == '__main__': st=time.time() study = optuna.create_study(study_name='test', direction='minimize') study.optimize(objective, n_trials=500) print(study.best_params) print(study.best_trial) print(study.best_trial.value) print(time.time()-st) optuna.visualization.plot_param_importances(study).show() optuna.visualization.plot_optimization_history(study).show() optuna.visualization.plot_slice(study).show()
本文链接地址:https://www.jiuchutong.com/zhishi/298419.html 转载请保留说明!

上一篇:Transformer 中的mask(transformer add norm)

下一篇:SSD训练数据集流程(学习记录)(ssd训练自己的数据集pytorch)

  • 华为手机怎么进入recovery模式(华为手机怎么进入简易模式)

    华为手机怎么进入recovery模式(华为手机怎么进入简易模式)

  • 硬盘坏了能修吗(电脑硬盘坏了能修吗)

    硬盘坏了能修吗(电脑硬盘坏了能修吗)

  • WPS艺术字高宽怎么设置(wps艺术字高度)

    WPS艺术字高宽怎么设置(wps艺术字高度)

  • 新浪微博注销了会怎样(新浪微博注销了还能重新申请吗)

    新浪微博注销了会怎样(新浪微博注销了还能重新申请吗)

  • 电脑找不到苹果手机热点怎么办(电脑找不到苹果手机的个人热点)

    电脑找不到苹果手机热点怎么办(电脑找不到苹果手机的个人热点)

  • 打印机adf盖在哪里(打印机adf盖传感器复位图解)

    打印机adf盖在哪里(打印机adf盖传感器复位图解)

  • 苹果电话卡怎么放进去(苹果电话卡怎么设置主卡副卡)

    苹果电话卡怎么放进去(苹果电话卡怎么设置主卡副卡)

  • 苹果长时间没用 关机了现在充电开不了(苹果长时间没用怎么恢复)

    苹果长时间没用 关机了现在充电开不了(苹果长时间没用怎么恢复)

  • nova7怎么截屏(nova7怎么截屏保)

    nova7怎么截屏(nova7怎么截屏保)

  • 电脑清灰多久一次(电脑清灰大约多久)

    电脑清灰多久一次(电脑清灰大约多久)

  • 手机能设置不收短信功能吗(手机设置不收彩信)

    手机能设置不收短信功能吗(手机设置不收彩信)

  • hdmi1.4和2.0的插头区别(hdmi1.4和2.0接口的区别)

    hdmi1.4和2.0的插头区别(hdmi1.4和2.0接口的区别)

  • vivox7能不能设置应用分身(vivoy70s怎么设置)

    vivox7能不能设置应用分身(vivoy70s怎么设置)

  • 微信支付首选怎么设置(微信支付 首选)

    微信支付首选怎么设置(微信支付 首选)

  • 手机上怎么看街景(手机怎么看街道监控)

    手机上怎么看街景(手机怎么看街道监控)

  • 小米盒子不能进去应用商店怎么办(小米盒子不能进入主界面怎么办)

    小米盒子不能进去应用商店怎么办(小米盒子不能进入主界面怎么办)

  • 酷我音乐怎么录歌(酷我音乐怎么录音发到微信上)

    酷我音乐怎么录歌(酷我音乐怎么录音发到微信上)

  • miui10使用记录怎么查(miui10如何看使用记录)

    miui10使用记录怎么查(miui10如何看使用记录)

  • 新开传世,产品介绍(传世新传)

    新开传世,产品介绍(传世新传)

  • 基本数据类型所占字节(基本数据类型所占空间)

    基本数据类型所占字节(基本数据类型所占空间)

  • 手机储存内存影响速度吗(手机内存影响性能吗)

    手机储存内存影响速度吗(手机内存影响性能吗)

  • 小米手环4qq为什么不提示消息(小米手环4不推送qq)

    小米手环4qq为什么不提示消息(小米手环4不推送qq)

  • win10头像怎么清除(win10改头像怎么删除以前的头像)

    win10头像怎么清除(win10改头像怎么删除以前的头像)

  • vue解决Not allowed to load local resource(vue解决跨域问题)

    vue解决Not allowed to load local resource(vue解决跨域问题)

  • 应用程序无法正常启动0xc0150002解决方法(应用程序无法正常启动0xc000007b)

    应用程序无法正常启动0xc0150002解决方法(应用程序无法正常启动0xc000007b)

  • yolov5加入CBAM,SE,CA,ECA注意力机制,纯代码(22.3.1还更新)(yolov5加入注意力机制后网络后进行剪枝)

    yolov5加入CBAM,SE,CA,ECA注意力机制,纯代码(22.3.1还更新)(yolov5加入注意力机制后网络后进行剪枝)

  • jwhois命令  whois 客户端服务(命令who的含义)

    jwhois命令 whois 客户端服务(命令who的含义)

  • 营改增抵减的销项税额会计分录
  • 补缴的社保可以报销吗
  • 土地价款抵扣增值税
  • 我国增值税征收范围
  • 增值税退税如何做账
  • 营业外收入交企业所得税可以扣除成本么
  • 资产负债表的货币资金根据什么填
  • 缴纳销项税额要交税吗
  • 代扣代缴个税对企业所得税的影响
  • 吊车租赁可以开6个点专票吗
  • 调整增值税误差的原因
  • 转租的门面怎么办营业执照
  • 投资子公司的现金流量
  • 发票丢失可以冲销吗
  • 服务费专票普票
  • 增值税普通发票申报
  • 不动产有法律效力吗
  • 什么时候抵扣增值税
  • 建筑行业印花税税率
  • 预提厂房租金
  • 环卫公司增值税税率
  • 进口增值税可以抵扣销项税额吗
  • 反避税的意义
  • 公司车辆违章怎么办
  • 贷款利息 发票
  • 应收账款和预收账款有什么区别
  • 存出保证金计入货币资金吗
  • win10新装系统我的电脑在哪
  • 企业若需要给客户交税
  • 固定资产增值税税率
  • php的正则表达式
  • php获取变量长度
  • 竣工结算审计费计入什么科目
  • 比弗利山庄安全吗治安
  • php编程计算日期怎么算
  • cookie与session的作用和原理
  • open api平台
  • 微服务框架图
  • php自动加载函数
  • 给最爱的他
  • vue3+ts+MicroApp实战教程
  • mysql中文乱码怎样用代码解决
  • 运输公司燃油费占比
  • 净资产包含哪些方面
  • 占地面积法如何分摊土地成本
  • 劳动仲裁经济补偿金写多了
  • 以前年度损益影响当期损益吗
  • 公司名下的车怎样领免检标志
  • 购买软件使用权计入无形资产吗
  • 一般销售商品收入怎么算
  • 铁路运费印花税怎么算
  • 工厂厨房厨具
  • 待摊费用科目分录
  • 利息支出应计入什么科目
  • 怎么理解什么是生命
  • 坏账准备确认坏账
  • 当月销售次月开票成本怎么结转
  • 工厂不开票怎么办
  • 小规模印花税怎么报
  • 货币资金核算内容
  • 各单位都需设置的是
  • ubuntu开启图形化界面
  • ubuntu20.04怎么安装
  • pages怎么标记
  • centos7权限管理
  • wrme.exe是什么
  • 电脑无法检测到麦克风怎么办
  • spmgr.exe - spmgr是什么进程 有什么用
  • linux使用yum
  • centos7.2安装
  • win8系统咋样
  • win7系统打开excel文件很慢或未响应
  • ExtJS Ext.MessageBox.alert()弹出对话框详解
  • net命令大全
  • unity3d游戏开发标准教程pdf
  • jquery设计模式
  • javascript for in
  • 出口免税不退税主要适用于什么情形
  • 新公司税务报到流程步骤
  • 公司税务认证
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设