位置: IT常识 - 正文

超参数调优框架optuna(可配合pytorch)(超参数设置)

编辑:rootadmin
超参数调优框架optuna(可配合pytorch) 目录前言一、optuna的使用流程二、结果可视化三、pytorch代码使用optuna前言

推荐整理分享超参数调优框架optuna(可配合pytorch)(超参数设置),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:超参数设置,模型超参数调整,参数调优方法,超参数调优的作用,超参数选择,超参数调节,超参数优化算法,超参数优化,内容如对您有帮助,希望把文章链接给更多的朋友!

在深度学习快速发展的今天,对于不同深度学习模型的超参数优化(hyperparameter optimization),始终是一个比较头痛的问题。在超参较少的情况下,grid search是比较常见的方式,但是随着超参数量的不断增多,特别是对于神经网络而言,训练过程的超参和NN本身的超参组成的参数空间是巨大的,grid search方法会消耗巨大的资源,而且效果很差,因此寻找一个“机器炼丹”的框架十分必要。

optuna 是一个十分常用的超参数调优框架,具有操作简单,嵌入式强和动态调整参数空间等优点。另外还有其他框架也可以进行超参优化,如李沐老师提到的automl等。

一、optuna的使用流程

首先需要在命令行 pip install optuna 载入这个第三方库,载入之后import即可。

optuna中需要注意几个关键的名词: trail::一次实验 study::一次学习过程(包括多次实验)

import optunadef obj(trail):x = trail.suggest_float('x',1,5)return (x-3)*(x-3)stu = optuna.creat_study(study_name = 'test', direction = 'minimize')stu.optimize(obj, n_trials = 50)print(study.best_params)print(study.best_trial)print(study.best_trial.value)超参数调优框架optuna(可配合pytorch)(超参数设置)

该段实例代码中,函数obj定义一个含参数的需要优化的模块,带调整的超参数为 ‘x’ ,返回值为该模块的 objective value。超参x的类型为float,可调整空间为 [1,5] 左右都闭区间,常用的还有suggest_int表示整型,suggest_categorical表示字符串集合。

trail.suggest_int('name', 10, 50)trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh'])

study表示一个学习过程,direction参数为“minimize”表示对函数obj 的返回值(同时也是每次trial的objective value)向最小的方向优化。

二、结果可视化

optuna.visualization中包含了丰富的可视化工具。比较推荐使用的是以下三个:

optuna.visualization.plot_param_importances(stu).show()optuna.visualization.plot_optimization_history(stu).show()optuna.visualization.plot_slice(stu).show()

plot_param_importances 展示各个超参数对结果影响的重要性

plot_optimization_history 展示在n_trail 个trail中每次的objective value和当前的最优解

plot_slice 展示每个超参数在所有trail中取值的分布,以散点图的形式

三、pytorch代码使用optuna

在pytorch构建的MLP中进行使用,可以看到该调参框架是十分灵活的,可以设置训练参数,如batchsize,learning rate,也可也设置NN的参数,如隐藏层数目,激活函数类型等。

import torchfrom torch import nn, optimfrom torch.utils.data import DataLoaderfrom torch.autograd import Variable # 获取变量import optunadef train(batch_size, learning_rate, lossfunc, opt, hidden_layer, activefunc, weightdk,momentum): # 选出一些超参数 trainset_num = 800 testset_num = 50 train_dataset = myDataset(trainset_num) test_dataset = myDataset(testset_num) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) # 创建CNN模型, 并设置损失函数及优化器 model = MLP(hidden_layer, activefunc).cuda() # print(model) if lossfunc == 'MSE': criterion = nn.MSELoss().cuda() elif lossfunc == 'MAE': criterion = nn.L1Loss() # optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weightdk) optimizer =optim.RMSprop(model.parameters(),lr=learning_rate,weight_decay=weightdk, momentum=momentum) # 训练过程 for epoch in range(num_epoches): # 训练模式 model.train() for i, data in enumerate(train_loader): inputs, labels, _ = data inputs = Variable(inputs).float().cuda() labels = Variable(labels).float().cuda() # 前向传播 out = model(inputs) # 可以考虑加正则项 train_loss = criterion(out, labels) optimizer.zero_grad() train_loss.backward() optimizer.step() model.eval() testloss = test() #返回测试集合上的MAE print('Test MAE = ', resloss) return reslossdef objective(trail): batchsize = trail.suggest_int('batchsize', 1, 16) lr = trail.suggest_float('lr', 1e-4, 1e-2,step=0.0001) lossfunc = trail.suggest_categorical('loss', ['MSE', 'MAE']) opt = trail.suggest_categorical('opt', ['Adam', 'SGD']) hidden_layer = trail.suggest_int('hiddenlayer', 20, 1200) activefunc = trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh']) weightdekey = trail.suggest_float('weight_dekay', 0, 1,step=0.01) momentum= trail.suggest_float('momentum',0,1,step=0.01) loss = train(batchsize, lr, lossfunc, opt, hidden_layer, activefunc, weightdekey,momentum) return lossif __name__ == '__main__': st=time.time() study = optuna.create_study(study_name='test', direction='minimize') study.optimize(objective, n_trials=500) print(study.best_params) print(study.best_trial) print(study.best_trial.value) print(time.time()-st) optuna.visualization.plot_param_importances(study).show() optuna.visualization.plot_optimization_history(study).show() optuna.visualization.plot_slice(study).show()
本文链接地址:https://www.jiuchutong.com/zhishi/298419.html 转载请保留说明!

上一篇:Transformer 中的mask(transformer add norm)

下一篇:SSD训练数据集流程(学习记录)(ssd训练自己的数据集pytorch)

  • 抖音直播开通需要费用吗(抖音直播开通需要多久)

    抖音直播开通需要费用吗(抖音直播开通需要多久)

  • 为什么抖音粉丝增加不显示(为什么抖音粉丝每天都在掉)

    为什么抖音粉丝增加不显示(为什么抖音粉丝每天都在掉)

  • 为什么电话一打就断(为什么每次打电话)

    为什么电话一打就断(为什么每次打电话)

  • 安兔兔跑分多少才算好

    安兔兔跑分多少才算好

  • win10 正在准备windows 请不要关闭(w10正在准备windows)

    win10 正在准备windows 请不要关闭(w10正在准备windows)

  • 小米手环第一次充电要多久(小米手环第一次充电)

    小米手环第一次充电要多久(小米手环第一次充电)

  • ipod touch7和6的区别(ipodtouch6跟7有什么区别)

    ipod touch7和6的区别(ipodtouch6跟7有什么区别)

  • psg1218是千兆路由器吗(psg1218 k2路由器是百兆还是千兆)

    psg1218是千兆路由器吗(psg1218 k2路由器是百兆还是千兆)

  • 抖音发言太快多久恢复(抖音发言太快了,请你控计里需要多久才能重新)

    抖音发言太快多久恢复(抖音发言太快了,请你控计里需要多久才能重新)

  • 小米8手机用打开双4g吗(小米8手机打电话接电话闪屏)

    小米8手机用打开双4g吗(小米8手机打电话接电话闪屏)

  • 苹果id可以彻底注销吗(苹果id彻底永久注销要去官网)

    苹果id可以彻底注销吗(苹果id彻底永久注销要去官网)

  • 快手超级管理员有几个(快手超级管理员怎么设置)

    快手超级管理员有几个(快手超级管理员怎么设置)

  • 荣耀play可以升级emui10吗(荣耀play可以升级magic ui系统吗)

    荣耀play可以升级emui10吗(荣耀play可以升级magic ui系统吗)

  • 手机语音报时怎么设置(智能手机语音报时)

    手机语音报时怎么设置(智能手机语音报时)

  • 诺基亚x6后盖怎么拆(诺基亚6后盖怎么打开)

    诺基亚x6后盖怎么拆(诺基亚6后盖怎么打开)

  • 华为荣耀3手环怎么关机(华为荣耀3手环价格)

    华为荣耀3手环怎么关机(华为荣耀3手环价格)

  • 不开网络微信运动计步么(不开网络微信运动怎么计步)

    不开网络微信运动计步么(不开网络微信运动怎么计步)

  • 免打扰模式能接电话吗(免打扰模式接不到电话吗)

    免打扰模式能接电话吗(免打扰模式接不到电话吗)

  • 苹果左上角两个小白点(苹果左上角两个椭圆)

    苹果左上角两个小白点(苹果左上角两个椭圆)

  • 苹果x与苹果xs的区别(苹果x与苹果xs的机身材质一样吗)

    苹果x与苹果xs的区别(苹果x与苹果xs的机身材质一样吗)

  • 苹果手机怎么用蓝牙传软件(苹果手机怎么用副号拨打电话)

    苹果手机怎么用蓝牙传软件(苹果手机怎么用副号拨打电话)

  • 电脑开机显示屏显示无信号黑屏解决方法(电脑开机显示屏显示无信号黑屏怎么办)

    电脑开机显示屏显示无信号黑屏解决方法(电脑开机显示屏显示无信号黑屏怎么办)

  • 织梦dede专题不同节点不能选取同样文章的解决方法(织梦专题页模板)

    织梦dede专题不同节点不能选取同样文章的解决方法(织梦专题页模板)

  • 城建税减半征收政策文件
  • 计算土地增值税时允许扣除的项目有
  • 季度所得税计提
  • 没房分手的多吗
  • 房地产开发企业销售自行开发的房地产项目
  • 快递费运费物流费一样吗
  • 销售收入与营业费用的配比
  • 出租土地使用权属于出租不动产吗
  • 2020年个税全年累计扣除如何计算
  • 个体工商户需要报税吗
  • 企业发行债券的目的
  • 企业如何列支个人收入
  • 出口货物如果没收怎么办
  • 公司对外借款怎么做账
  • 合伙企业需要缴纳什么税
  • 2021申请一般纳税人公司的条件
  • 汇算清缴退税现金流量表
  • 增值税税负最终由谁承担
  • 代开的专票开错了怎么办?
  • 普通增值税 税点
  • 境外所得税收抵免政策
  • 资本公积可以怎么处理掉
  • 增值税发票时效性
  • 公司发票限额按什么计算
  • 公司制作小程序定金能放在图物资吗
  • 金税三期得死多少企业
  • 小规模纳税人销售自己使用过的物品
  • 员工可以一起辞职吗
  • 其他应收款会计科目
  • 上个月退货会计分录
  • 销售商品全部退回
  • 往来款的意义
  • 总分机构分摊比例如何确定
  • react+
  • 机器学习中的数学——距离定义(八):余弦距离(Cosine Distance)
  • get_module_base
  • 利润表管理费用包括哪些内容
  • 银行回单打回来会计要做什么
  • 报销具体流程
  • pythonsorted函数的作用
  • 总分公司企业所得税如何申报缴纳
  • 购买电脑的过程
  • 如何在税控盘上变更一般纳税人
  • 未能确认收入的原因
  • 保险公司的奖励制度
  • db2获取当前年月日
  • 第一次建账要填期初余额吗
  • 房地产公司项目开发流程
  • 折旧费用分摊科目是什么
  • 收到法人的借款怎样写摘要
  • 未开票收入如何申报
  • 收到其他公司款项会计分录
  • 短期借款利息计提分录怎么写
  • 收到增值税发票后该如何处理啊?
  • 给员工的奖励怎么做会计分录
  • 收付实现制下收入包括增值税吗
  • sql如何查出重复的数据
  • sql-3
  • windowsxp关机没反应
  • win7系统调节亮度快捷键
  • imac 2010 cpu
  • 如何去掉桌面图标的蓝底
  • f_00000e是什么文件
  • fpd文件是什么意思
  • centos配置yum
  • linux 消耗内存命令
  • 锁屏壁纸设置后不显示怎么办
  • win8.1激活方法
  • opengl画点
  • 获取android id
  • [置顶]bilinovel
  • 微信小程序实现文件上传
  • 学习雷锋好榜样
  • js插件大全
  • unity shader视频教程
  • 如何用javascript
  • js设计模型
  • 出口退税出现预缴怎么办
  • 建行代理贵金属签约
  • 军人残疾证家属享受待遇吗
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设