位置: IT常识 - 正文

超参数调优框架optuna(可配合pytorch)(超参数设置)

编辑:rootadmin
超参数调优框架optuna(可配合pytorch) 目录前言一、optuna的使用流程二、结果可视化三、pytorch代码使用optuna前言

推荐整理分享超参数调优框架optuna(可配合pytorch)(超参数设置),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:超参数设置,模型超参数调整,参数调优方法,超参数调优的作用,超参数选择,超参数调节,超参数优化算法,超参数优化,内容如对您有帮助,希望把文章链接给更多的朋友!

在深度学习快速发展的今天,对于不同深度学习模型的超参数优化(hyperparameter optimization),始终是一个比较头痛的问题。在超参较少的情况下,grid search是比较常见的方式,但是随着超参数量的不断增多,特别是对于神经网络而言,训练过程的超参和NN本身的超参组成的参数空间是巨大的,grid search方法会消耗巨大的资源,而且效果很差,因此寻找一个“机器炼丹”的框架十分必要。

optuna 是一个十分常用的超参数调优框架,具有操作简单,嵌入式强和动态调整参数空间等优点。另外还有其他框架也可以进行超参优化,如李沐老师提到的automl等。

一、optuna的使用流程

首先需要在命令行 pip install optuna 载入这个第三方库,载入之后import即可。

optuna中需要注意几个关键的名词: trail::一次实验 study::一次学习过程(包括多次实验)

import optunadef obj(trail):x = trail.suggest_float('x',1,5)return (x-3)*(x-3)stu = optuna.creat_study(study_name = 'test', direction = 'minimize')stu.optimize(obj, n_trials = 50)print(study.best_params)print(study.best_trial)print(study.best_trial.value)超参数调优框架optuna(可配合pytorch)(超参数设置)

该段实例代码中,函数obj定义一个含参数的需要优化的模块,带调整的超参数为 ‘x’ ,返回值为该模块的 objective value。超参x的类型为float,可调整空间为 [1,5] 左右都闭区间,常用的还有suggest_int表示整型,suggest_categorical表示字符串集合。

trail.suggest_int('name', 10, 50)trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh'])

study表示一个学习过程,direction参数为“minimize”表示对函数obj 的返回值(同时也是每次trial的objective value)向最小的方向优化。

二、结果可视化

optuna.visualization中包含了丰富的可视化工具。比较推荐使用的是以下三个:

optuna.visualization.plot_param_importances(stu).show()optuna.visualization.plot_optimization_history(stu).show()optuna.visualization.plot_slice(stu).show()

plot_param_importances 展示各个超参数对结果影响的重要性

plot_optimization_history 展示在n_trail 个trail中每次的objective value和当前的最优解

plot_slice 展示每个超参数在所有trail中取值的分布,以散点图的形式

三、pytorch代码使用optuna

在pytorch构建的MLP中进行使用,可以看到该调参框架是十分灵活的,可以设置训练参数,如batchsize,learning rate,也可也设置NN的参数,如隐藏层数目,激活函数类型等。

import torchfrom torch import nn, optimfrom torch.utils.data import DataLoaderfrom torch.autograd import Variable # 获取变量import optunadef train(batch_size, learning_rate, lossfunc, opt, hidden_layer, activefunc, weightdk,momentum): # 选出一些超参数 trainset_num = 800 testset_num = 50 train_dataset = myDataset(trainset_num) test_dataset = myDataset(testset_num) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) # 创建CNN模型, 并设置损失函数及优化器 model = MLP(hidden_layer, activefunc).cuda() # print(model) if lossfunc == 'MSE': criterion = nn.MSELoss().cuda() elif lossfunc == 'MAE': criterion = nn.L1Loss() # optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weightdk) optimizer =optim.RMSprop(model.parameters(),lr=learning_rate,weight_decay=weightdk, momentum=momentum) # 训练过程 for epoch in range(num_epoches): # 训练模式 model.train() for i, data in enumerate(train_loader): inputs, labels, _ = data inputs = Variable(inputs).float().cuda() labels = Variable(labels).float().cuda() # 前向传播 out = model(inputs) # 可以考虑加正则项 train_loss = criterion(out, labels) optimizer.zero_grad() train_loss.backward() optimizer.step() model.eval() testloss = test() #返回测试集合上的MAE print('Test MAE = ', resloss) return reslossdef objective(trail): batchsize = trail.suggest_int('batchsize', 1, 16) lr = trail.suggest_float('lr', 1e-4, 1e-2,step=0.0001) lossfunc = trail.suggest_categorical('loss', ['MSE', 'MAE']) opt = trail.suggest_categorical('opt', ['Adam', 'SGD']) hidden_layer = trail.suggest_int('hiddenlayer', 20, 1200) activefunc = trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh']) weightdekey = trail.suggest_float('weight_dekay', 0, 1,step=0.01) momentum= trail.suggest_float('momentum',0,1,step=0.01) loss = train(batchsize, lr, lossfunc, opt, hidden_layer, activefunc, weightdekey,momentum) return lossif __name__ == '__main__': st=time.time() study = optuna.create_study(study_name='test', direction='minimize') study.optimize(objective, n_trials=500) print(study.best_params) print(study.best_trial) print(study.best_trial.value) print(time.time()-st) optuna.visualization.plot_param_importances(study).show() optuna.visualization.plot_optimization_history(study).show() optuna.visualization.plot_slice(study).show()
本文链接地址:https://www.jiuchutong.com/zhishi/298419.html 转载请保留说明!

上一篇:Transformer 中的mask(transformer add norm)

下一篇:SSD训练数据集流程(学习记录)(ssd训练自己的数据集pytorch)

  • 上海公交车可以微信支付吗(上海公交车可以刷两次吗)

    上海公交车可以微信支付吗(上海公交车可以刷两次吗)

  • ios14轻点背面支持什么机型(ios14.1轻点背面)

    ios14轻点背面支持什么机型(ios14.1轻点背面)

  • 如何连接网络打印机(如何连接网络打印机具体步骤)

    如何连接网络打印机(如何连接网络打印机具体步骤)

  • 腾讯会议流量消耗大吗(腾讯会议流量消耗)

    腾讯会议流量消耗大吗(腾讯会议流量消耗)

  • 升级miui12会清除数据吗(升级miui12后还能退回去吗)

    升级miui12会清除数据吗(升级miui12后还能退回去吗)

  • 电脑开机显示无信号然后黑屏怎么回事(电脑开机显示无驱动器怎么办)

    电脑开机显示无信号然后黑屏怎么回事(电脑开机显示无驱动器怎么办)

  • 蓝牙耳机摔开了合不上(蓝牙耳机摔开了装回去了但是却不吻合)

    蓝牙耳机摔开了合不上(蓝牙耳机摔开了装回去了但是却不吻合)

  • 拼多多为什么微信支付不了(拼多多为什么微信支付宝都能用)

    拼多多为什么微信支付不了(拼多多为什么微信支付宝都能用)

  • 微信有人投诉可以查到是谁投诉的吗(别人微信投诉我,我会收到提示吗)

    微信有人投诉可以查到是谁投诉的吗(别人微信投诉我,我会收到提示吗)

  • 抖音动态壁纸下载失败怎么回事(抖音动态壁纸下载到哪了)

    抖音动态壁纸下载失败怎么回事(抖音动态壁纸下载到哪了)

  • 腾讯课堂老师可以看学生的摄像头吗(腾讯课堂老师可以看到学生分屏在干什么吗)

    腾讯课堂老师可以看学生的摄像头吗(腾讯课堂老师可以看到学生分屏在干什么吗)

  • 小米手环4续航能力(小米手环4续航不足一天)

    小米手环4续航能力(小米手环4续航不足一天)

  • 盲插是什么意思(盲插接口又叫什么)

    盲插是什么意思(盲插接口又叫什么)

  • 浏览公众号对方知道吗(浏览公众号对方会知道吗)

    浏览公众号对方知道吗(浏览公众号对方会知道吗)

  • 小米常驻通知是什么意思(小米设置常驻通知)

    小米常驻通知是什么意思(小米设置常驻通知)

  • 面容识别坏了可以修吗(面容识别坏了怎么回事)

    面容识别坏了可以修吗(面容识别坏了怎么回事)

  • vue怎么拼图片和视频(vue图片叠加显示)

    vue怎么拼图片和视频(vue图片叠加显示)

  • soul拉黑了怎么加回来(soul拉黑了怎么分享名片)

    soul拉黑了怎么加回来(soul拉黑了怎么分享名片)

  • 显示器acin是插什么的(显示器ac和dc电源)

    显示器acin是插什么的(显示器ac和dc电源)

  • 抖音播放为0怎么解决(抖音播放0怎么回事)

    抖音播放为0怎么解决(抖音播放0怎么回事)

  • 夏普电视怎么连接手机(夏普电视怎么连接机顶盒)

    夏普电视怎么连接手机(夏普电视怎么连接机顶盒)

  • 怎么安装WIN7系统?(怎么安装win7系统后怎么安装驱动)

    怎么安装WIN7系统?(怎么安装win7系统后怎么安装驱动)

  • 如何选择一款适合自己的家用路由器(如何选择一款适合自己家庭的凉席)

    如何选择一款适合自己的家用路由器(如何选择一款适合自己家庭的凉席)

  • OK源码中国首发微擎破解模块首页主题永和自适应代理首页v9.1.3-OK源码破解(okr开源软件)

    OK源码中国首发微擎破解模块首页主题永和自适应代理首页v9.1.3-OK源码破解(okr开源软件)

  • 工资表个税多扣了账务处理递减
  • 小规模增值税会计处理流程
  • 企业计提印花税会计处理
  • 税务现金流量表怎么填
  • 建筑行业增值税税负率一般控制在多少合适
  • 出口的港杂费包括哪些
  • 企业所得税清算报备表清算结束日
  • 分支机构是不是需要设立账簿
  • 增值税减免备案改备查后续管理
  • 制造费用明细账实例图
  • 制造费用包括哪三类
  • 供应链公司的组织架构图
  • 合并利润表抵消事项包括
  • 什么是企业所得税收入
  • 公司怎么给个人开票
  • 增值税抵扣凭证包括桥闸通行费发票
  • 税收名词汇编
  • 销售退回怎么开票
  • 可供出售债券投资
  • 支票作废了需要什么材料
  • 无偿转让股权需要交什么税
  • 周转材料月末有余额吗
  • 小规模纳税人开票限额是多少
  • 关于幼儿园会没课程的会刊
  • cpu天梯图2022最新版1240p
  • 苹果六微信
  • hypertrm.exe系统错误
  • 票据融资都有哪些方式
  • 销户余额转出总公司怎样记账
  • 产供销一体化什么意思
  • ChatGLM-6B (介绍相关概念、基础环境搭建及部署)
  • 非货币性资产交换
  • 购买免税农产品可以抵扣进项税
  • uniapp动态设置标题
  • 以前年度损益调整结转到哪里
  • php获取北京时间
  • 前后端分离与不分离
  • 长期借款利息的账务处理涉及的会计科目有
  • 智能优化算法书籍推荐
  • web主要的请求方式有几种
  • PHP中set_include_path()函数相关用法分析
  • php代理访问
  • 外购存货的成本包括哪些内容
  • 如何减税降税
  • phpcms迁移
  • 未开票收入缴纳增值税吗
  • 水运企业会计核算办法
  • 不征税收入和免税收入有哪些项目
  • 关联企业债资比怎么计算
  • 申请电子发票需要什么条件
  • 资产负债表期末余额是累计数吗
  • 营业总收入包含什么
  • 保险补偿多久到账
  • 扣员工工作服费用合法吗?
  • 现金流量表利息支出
  • 基层工会经费收入来源包括
  • 专项应付款能转出吗
  • 汽车配件属于什么业务类型
  • 应付利息的主要方式
  • 小微企业取得的进项税能不能抵扣
  • 小规模公司怎样添加员工
  • 支付劳务费需要什么原始凭证
  • 镜的镜像截图
  • wysafe.exe是什么
  • xp系统ie浏览器怎么升级
  • linux恢复rm删除目录
  • Win7登录密码
  • 获取windows的最新信息
  • python基本入门
  • python引用方法
  • sockaddr_in和sockaddr
  • nodejs重命名文件
  • python 字典的字典
  • javascript 基础
  • js foreach倒序
  • javascript substring的用法
  • python代码规范化
  • 银行开业送什么花
  • 企业所得税账务如何处理
  • 微信如何查询个人名下所有银行卡
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设