位置: IT常识 - 正文

超参数调优框架optuna(可配合pytorch)(超参数设置)

编辑:rootadmin
超参数调优框架optuna(可配合pytorch) 目录前言一、optuna的使用流程二、结果可视化三、pytorch代码使用optuna前言

推荐整理分享超参数调优框架optuna(可配合pytorch)(超参数设置),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:超参数设置,模型超参数调整,参数调优方法,超参数调优的作用,超参数选择,超参数调节,超参数优化算法,超参数优化,内容如对您有帮助,希望把文章链接给更多的朋友!

在深度学习快速发展的今天,对于不同深度学习模型的超参数优化(hyperparameter optimization),始终是一个比较头痛的问题。在超参较少的情况下,grid search是比较常见的方式,但是随着超参数量的不断增多,特别是对于神经网络而言,训练过程的超参和NN本身的超参组成的参数空间是巨大的,grid search方法会消耗巨大的资源,而且效果很差,因此寻找一个“机器炼丹”的框架十分必要。

optuna 是一个十分常用的超参数调优框架,具有操作简单,嵌入式强和动态调整参数空间等优点。另外还有其他框架也可以进行超参优化,如李沐老师提到的automl等。

一、optuna的使用流程

首先需要在命令行 pip install optuna 载入这个第三方库,载入之后import即可。

optuna中需要注意几个关键的名词: trail::一次实验 study::一次学习过程(包括多次实验)

import optunadef obj(trail):x = trail.suggest_float('x',1,5)return (x-3)*(x-3)stu = optuna.creat_study(study_name = 'test', direction = 'minimize')stu.optimize(obj, n_trials = 50)print(study.best_params)print(study.best_trial)print(study.best_trial.value)超参数调优框架optuna(可配合pytorch)(超参数设置)

该段实例代码中,函数obj定义一个含参数的需要优化的模块,带调整的超参数为 ‘x’ ,返回值为该模块的 objective value。超参x的类型为float,可调整空间为 [1,5] 左右都闭区间,常用的还有suggest_int表示整型,suggest_categorical表示字符串集合。

trail.suggest_int('name', 10, 50)trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh'])

study表示一个学习过程,direction参数为“minimize”表示对函数obj 的返回值(同时也是每次trial的objective value)向最小的方向优化。

二、结果可视化

optuna.visualization中包含了丰富的可视化工具。比较推荐使用的是以下三个:

optuna.visualization.plot_param_importances(stu).show()optuna.visualization.plot_optimization_history(stu).show()optuna.visualization.plot_slice(stu).show()

plot_param_importances 展示各个超参数对结果影响的重要性

plot_optimization_history 展示在n_trail 个trail中每次的objective value和当前的最优解

plot_slice 展示每个超参数在所有trail中取值的分布,以散点图的形式

三、pytorch代码使用optuna

在pytorch构建的MLP中进行使用,可以看到该调参框架是十分灵活的,可以设置训练参数,如batchsize,learning rate,也可也设置NN的参数,如隐藏层数目,激活函数类型等。

import torchfrom torch import nn, optimfrom torch.utils.data import DataLoaderfrom torch.autograd import Variable # 获取变量import optunadef train(batch_size, learning_rate, lossfunc, opt, hidden_layer, activefunc, weightdk,momentum): # 选出一些超参数 trainset_num = 800 testset_num = 50 train_dataset = myDataset(trainset_num) test_dataset = myDataset(testset_num) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) # 创建CNN模型, 并设置损失函数及优化器 model = MLP(hidden_layer, activefunc).cuda() # print(model) if lossfunc == 'MSE': criterion = nn.MSELoss().cuda() elif lossfunc == 'MAE': criterion = nn.L1Loss() # optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weightdk) optimizer =optim.RMSprop(model.parameters(),lr=learning_rate,weight_decay=weightdk, momentum=momentum) # 训练过程 for epoch in range(num_epoches): # 训练模式 model.train() for i, data in enumerate(train_loader): inputs, labels, _ = data inputs = Variable(inputs).float().cuda() labels = Variable(labels).float().cuda() # 前向传播 out = model(inputs) # 可以考虑加正则项 train_loss = criterion(out, labels) optimizer.zero_grad() train_loss.backward() optimizer.step() model.eval() testloss = test() #返回测试集合上的MAE print('Test MAE = ', resloss) return reslossdef objective(trail): batchsize = trail.suggest_int('batchsize', 1, 16) lr = trail.suggest_float('lr', 1e-4, 1e-2,step=0.0001) lossfunc = trail.suggest_categorical('loss', ['MSE', 'MAE']) opt = trail.suggest_categorical('opt', ['Adam', 'SGD']) hidden_layer = trail.suggest_int('hiddenlayer', 20, 1200) activefunc = trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh']) weightdekey = trail.suggest_float('weight_dekay', 0, 1,step=0.01) momentum= trail.suggest_float('momentum',0,1,step=0.01) loss = train(batchsize, lr, lossfunc, opt, hidden_layer, activefunc, weightdekey,momentum) return lossif __name__ == '__main__': st=time.time() study = optuna.create_study(study_name='test', direction='minimize') study.optimize(objective, n_trials=500) print(study.best_params) print(study.best_trial) print(study.best_trial.value) print(time.time()-st) optuna.visualization.plot_param_importances(study).show() optuna.visualization.plot_optimization_history(study).show() optuna.visualization.plot_slice(study).show()
本文链接地址:https://www.jiuchutong.com/zhishi/298419.html 转载请保留说明!

上一篇:Transformer 中的mask(transformer add norm)

下一篇:SSD训练数据集流程(学习记录)(ssd训练自己的数据集pytorch)

  • 苹果手机如何关闭浮动圆圈(苹果手机如何关空调)

    苹果手机如何关闭浮动圆圈(苹果手机如何关空调)

  • 华为nova5怎么设置无线充电(华为nova5怎么设置门禁)

    华为nova5怎么设置无线充电(华为nova5怎么设置门禁)

  • 苹果相机怎么设置颗粒(苹果相机怎么设置照片比例)

    苹果相机怎么设置颗粒(苹果相机怎么设置照片比例)

  • 怎么设置拼多多不让别人看到我买的东西(怎么设置拼多多访问相册)

    怎么设置拼多多不让别人看到我买的东西(怎么设置拼多多访问相册)

  • 怎样下载手机彩铃(安卓手机怎么下载彩票)

    怎样下载手机彩铃(安卓手机怎么下载彩票)

  • 声卡为什么要安装机架(声卡为什么安装不了)

    声卡为什么要安装机架(声卡为什么安装不了)

  • 小米10曲面屏边缘发绿(小米10曲面屏边框怎么拆)

    小米10曲面屏边缘发绿(小米10曲面屏边框怎么拆)

  • 华为tit-al00什么型号(华为 trt-al00)

    华为tit-al00什么型号(华为 trt-al00)

  • qq好友辅助成功后多久发短信(qq好友辅助成功没反应)

    qq好友辅助成功后多久发短信(qq好友辅助成功没反应)

  • 苹果手机怎么强行关机快捷键(苹果手机怎么强制关机)

    苹果手机怎么强行关机快捷键(苹果手机怎么强制关机)

  • 华为caztl10是啥型号(华为caz-al10是什么型号?)

    华为caztl10是啥型号(华为caz-al10是什么型号?)

  • ipad闪退黑屏怎么修复(ipad闪退黑屏怎么办)

    ipad闪退黑屏怎么修复(ipad闪退黑屏怎么办)

  • oppor11黑屏打不开怎么办(oppor11黑屏打不开怎么办关机关不上)

    oppor11黑屏打不开怎么办(oppor11黑屏打不开怎么办关机关不上)

  • 拼多多拼不成功钱能退还吗(拼多多拼不成功卖家给我发货)

    拼多多拼不成功钱能退还吗(拼多多拼不成功卖家给我发货)

  • 努比亚红魔3s的闪存规格是多少(努比亚 红魔3s)

    努比亚红魔3s的闪存规格是多少(努比亚 红魔3s)

  • oppo手机动态锁屏壁纸怎么设置(oppo手机动态锁屏壁纸怎么关闭)

    oppo手机动态锁屏壁纸怎么设置(oppo手机动态锁屏壁纸怎么关闭)

  • iphone6s plus电池容量(iphone6splus电池容量是多少)

    iphone6s plus电池容量(iphone6splus电池容量是多少)

  • 咸鱼会员名怎么改更改(咸鱼会员名怎么和淘宝一样)

    咸鱼会员名怎么改更改(咸鱼会员名怎么和淘宝一样)

  • 拼多多恢复删除订单(删除的拼多多怎样恢复)

    拼多多恢复删除订单(删除的拼多多怎样恢复)

  • surface笔的用处大吗(surface笔有用吗)

    surface笔的用处大吗(surface笔有用吗)

  • 快手特别关注有啥用(快手特别关注有提示吗)

    快手特别关注有啥用(快手特别关注有提示吗)

  • iphone照片反色(iphone照片反色app)

    iphone照片反色(iphone照片反色app)

  • 苹果耳机一只忽然不响了(苹果耳机一只忽然不响了可以修吗)

    苹果耳机一只忽然不响了(苹果耳机一只忽然不响了可以修吗)

  • 手机屏幕lcd和led的区别(手机屏幕lcd和led显示器的区别)

    手机屏幕lcd和led的区别(手机屏幕lcd和led显示器的区别)

  • 个人出租非住房房产税怎么计算
  • 应税工资怎么计算出来的
  • 税收分类编码怎么导出来
  • 高档珍珠镶嵌
  • 公司买手表账务处理
  • 发票开具就能做账了吗
  • 期间损益结转错误怎么冲销
  • 小规模纳税人收入账务处理
  • 去年已认证发票红冲怎么报税
  • 接受现金捐赠怎么写分录
  • 河道工程修建维护管理费何时开始停征?
  • 企业工会经费不足,可以向企业拨款吗
  • 物业公司收小区物业费吗
  • 金税盘抵减税款分录
  • 建筑业预缴附加税分录
  • 房屋租赁税费征收的时间是多久
  • 社保显示已录入什么意思
  • 加工费月底需要全部结转吗
  • 收到上月发票怎么写分录
  • 2023增值税免税政策
  • 基金预算收入核算的内容包括
  • 库存商品结转成本
  • 电脑桌面刷新反应迟钝
  • 总承包简易计税
  • 珠宝加工税率是多少
  • 要约与要约邀请的主要区别
  • ps4运行windows
  • 分红个人所得税在哪里查询
  • 发票差额怎样做分录
  • 怎么修改wifi密码视频教程
  • 销售商品收到商业汇票一张会计分录
  • 劳务的完成程度可以采用如下方法确定
  • 公积金补缴需要去柜台吗
  • 购买材料时采购会计分录
  • uniapp vuecli
  • 发票开具的有哪些原则
  • ai implementation
  • php去除字符串中的引号
  • 被称为下一代风华的是
  • java前后端加密解密请求
  • cd播放模式
  • php使用什么开发工具
  • 货币盘盈盘亏账怎么算
  • 外地工程预缴的个人所得税是什么申报
  • 员工办理健康证需要什么材料
  • 个体工商户税务登记需要哪些资料
  • python文件间传递参数
  • 消防收费标准
  • 汽车租赁费怎么赋码
  • 水电费没有票怎么做账
  • 周转材料应该计入什么科目
  • 租赁房产税计税依据及计算方式是什么
  • 应收账款资产减值准备可以在所得税前扣除吗
  • 库存商品进项税额转出分录怎么写
  • 企业股东撤资如何清算
  • 多收发票会计分录
  • 销售一批产品给丙公司,该批产品标价200万yuan
  • 如何开具发票?
  • 企业搬迁补偿款免税的法律依据是什么
  • mysql数据库sid
  • mysql多字段数据
  • macbookair网页视频看不了
  • linux怎么vi
  • windows自带软件有哪些
  • centos7 lvcreate
  • msg0是什么文件
  • win7桌面壁纸自动更换关闭
  • cocos2d-x教程
  • js array insert
  • 怎么申请返回
  • linux中tar
  • Linux Shell中判断进程是否存在的方法
  • 字符串截取用什么方法
  • javascript用处
  • javascript面向对象精要
  • 财务报表的收入平稳
  • 国家税务局申报系统操作步骤说明在哪里看
  • uk怎么添加发票
  • 预交增值税附加税率
  • 河南省国家税务局发票查询官网
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设