位置: IT常识 - 正文

超参数调优框架optuna(可配合pytorch)(超参数设置)

编辑:rootadmin
超参数调优框架optuna(可配合pytorch) 目录前言一、optuna的使用流程二、结果可视化三、pytorch代码使用optuna前言

推荐整理分享超参数调优框架optuna(可配合pytorch)(超参数设置),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:超参数设置,模型超参数调整,参数调优方法,超参数调优的作用,超参数选择,超参数调节,超参数优化算法,超参数优化,内容如对您有帮助,希望把文章链接给更多的朋友!

在深度学习快速发展的今天,对于不同深度学习模型的超参数优化(hyperparameter optimization),始终是一个比较头痛的问题。在超参较少的情况下,grid search是比较常见的方式,但是随着超参数量的不断增多,特别是对于神经网络而言,训练过程的超参和NN本身的超参组成的参数空间是巨大的,grid search方法会消耗巨大的资源,而且效果很差,因此寻找一个“机器炼丹”的框架十分必要。

optuna 是一个十分常用的超参数调优框架,具有操作简单,嵌入式强和动态调整参数空间等优点。另外还有其他框架也可以进行超参优化,如李沐老师提到的automl等。

一、optuna的使用流程

首先需要在命令行 pip install optuna 载入这个第三方库,载入之后import即可。

optuna中需要注意几个关键的名词: trail::一次实验 study::一次学习过程(包括多次实验)

import optunadef obj(trail):x = trail.suggest_float('x',1,5)return (x-3)*(x-3)stu = optuna.creat_study(study_name = 'test', direction = 'minimize')stu.optimize(obj, n_trials = 50)print(study.best_params)print(study.best_trial)print(study.best_trial.value)超参数调优框架optuna(可配合pytorch)(超参数设置)

该段实例代码中,函数obj定义一个含参数的需要优化的模块,带调整的超参数为 ‘x’ ,返回值为该模块的 objective value。超参x的类型为float,可调整空间为 [1,5] 左右都闭区间,常用的还有suggest_int表示整型,suggest_categorical表示字符串集合。

trail.suggest_int('name', 10, 50)trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh'])

study表示一个学习过程,direction参数为“minimize”表示对函数obj 的返回值(同时也是每次trial的objective value)向最小的方向优化。

二、结果可视化

optuna.visualization中包含了丰富的可视化工具。比较推荐使用的是以下三个:

optuna.visualization.plot_param_importances(stu).show()optuna.visualization.plot_optimization_history(stu).show()optuna.visualization.plot_slice(stu).show()

plot_param_importances 展示各个超参数对结果影响的重要性

plot_optimization_history 展示在n_trail 个trail中每次的objective value和当前的最优解

plot_slice 展示每个超参数在所有trail中取值的分布,以散点图的形式

三、pytorch代码使用optuna

在pytorch构建的MLP中进行使用,可以看到该调参框架是十分灵活的,可以设置训练参数,如batchsize,learning rate,也可也设置NN的参数,如隐藏层数目,激活函数类型等。

import torchfrom torch import nn, optimfrom torch.utils.data import DataLoaderfrom torch.autograd import Variable # 获取变量import optunadef train(batch_size, learning_rate, lossfunc, opt, hidden_layer, activefunc, weightdk,momentum): # 选出一些超参数 trainset_num = 800 testset_num = 50 train_dataset = myDataset(trainset_num) test_dataset = myDataset(testset_num) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True) # 创建CNN模型, 并设置损失函数及优化器 model = MLP(hidden_layer, activefunc).cuda() # print(model) if lossfunc == 'MSE': criterion = nn.MSELoss().cuda() elif lossfunc == 'MAE': criterion = nn.L1Loss() # optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weightdk) optimizer =optim.RMSprop(model.parameters(),lr=learning_rate,weight_decay=weightdk, momentum=momentum) # 训练过程 for epoch in range(num_epoches): # 训练模式 model.train() for i, data in enumerate(train_loader): inputs, labels, _ = data inputs = Variable(inputs).float().cuda() labels = Variable(labels).float().cuda() # 前向传播 out = model(inputs) # 可以考虑加正则项 train_loss = criterion(out, labels) optimizer.zero_grad() train_loss.backward() optimizer.step() model.eval() testloss = test() #返回测试集合上的MAE print('Test MAE = ', resloss) return reslossdef objective(trail): batchsize = trail.suggest_int('batchsize', 1, 16) lr = trail.suggest_float('lr', 1e-4, 1e-2,step=0.0001) lossfunc = trail.suggest_categorical('loss', ['MSE', 'MAE']) opt = trail.suggest_categorical('opt', ['Adam', 'SGD']) hidden_layer = trail.suggest_int('hiddenlayer', 20, 1200) activefunc = trail.suggest_categorical('active', ['relu', 'sigmoid', 'tanh']) weightdekey = trail.suggest_float('weight_dekay', 0, 1,step=0.01) momentum= trail.suggest_float('momentum',0,1,step=0.01) loss = train(batchsize, lr, lossfunc, opt, hidden_layer, activefunc, weightdekey,momentum) return lossif __name__ == '__main__': st=time.time() study = optuna.create_study(study_name='test', direction='minimize') study.optimize(objective, n_trials=500) print(study.best_params) print(study.best_trial) print(study.best_trial.value) print(time.time()-st) optuna.visualization.plot_param_importances(study).show() optuna.visualization.plot_optimization_history(study).show() optuna.visualization.plot_slice(study).show()
本文链接地址:https://www.jiuchutong.com/zhishi/298419.html 转载请保留说明!

上一篇:Transformer 中的mask(transformer add norm)

下一篇:SSD训练数据集流程(学习记录)(ssd训练自己的数据集pytorch)

  • oppor17黑屏怎么强制重启(oppor17黑屏怎么回事)

    oppor17黑屏怎么强制重启(oppor17黑屏怎么回事)

  • 小米摄像头支持128g吗(小米摄像头支持onvif协议吗)

    小米摄像头支持128g吗(小米摄像头支持onvif协议吗)

  • 不拉黑不删除怎么才能不收到信息(不拉黑不删除怎么隐藏微信消息)

    不拉黑不删除怎么才能不收到信息(不拉黑不删除怎么隐藏微信消息)

  • tcp协议称为什么(tcp协议属于什么协议)

    tcp协议称为什么(tcp协议属于什么协议)

  • 5g单模和双模是什么意思(5g单模和双模手机哪个好)

    5g单模和双模是什么意思(5g单模和双模手机哪个好)

  • powerpoint中的版式指的是(ppt版本)

    powerpoint中的版式指的是(ppt版本)

  • 小米10息屏显示费电吗

    小米10息屏显示费电吗

  • oppoa31是全网通吗(oppoa32全网通)

    oppoa31是全网通吗(oppoa32全网通)

  • 怎么设置没开通朋友圈(怎么设置没开通花呗功能)

    怎么设置没开通朋友圈(怎么设置没开通花呗功能)

  • vivox30手机有几个颜色(vivox30规格)

    vivox30手机有几个颜色(vivox30规格)

  • 苹果黑名单如何移除(苹果黑名单如何设置)

    苹果黑名单如何移除(苹果黑名单如何设置)

  • 手机可以改银行卡密码吗(手机可以改银行密码吗)

    手机可以改银行卡密码吗(手机可以改银行密码吗)

  • 红米note8pro来电转移怎么设置(红米note8pro来电转接怎么设置)

    红米note8pro来电转移怎么设置(红米note8pro来电转接怎么设置)

  • 苹果手机快捷指令库在哪里(苹果手机快捷指令)

    苹果手机快捷指令库在哪里(苹果手机快捷指令)

  • 苹果系统怎么玩安卓区(苹果系统怎么玩皮卡堂)

    苹果系统怎么玩安卓区(苹果系统怎么玩皮卡堂)

  • nova是什么系列(hi nova是什么牌子手机)

    nova是什么系列(hi nova是什么牌子手机)

  • 华为nova5pro充电时间(华为nova5pro充电器型号)

    华为nova5pro充电时间(华为nova5pro充电器型号)

  • etc重新安装移动后如何重新激活(etc重新安装移动后如何重新激活收费吗)

    etc重新安装移动后如何重新激活(etc重新安装移动后如何重新激活收费吗)

  • 怎么制作腾讯视频短片(怎么制作腾讯视频会议)

    怎么制作腾讯视频短片(怎么制作腾讯视频会议)

  • 快手隐藏动态干嘛的(快手设置中在动态中隐藏自己的动态)

    快手隐藏动态干嘛的(快手设置中在动态中隐藏自己的动态)

  • 华为mate10怎么打开otg(华为mate10怎么打开高清通话)

    华为mate10怎么打开otg(华为mate10怎么打开高清通话)

  • 全民k歌怎样禁止访问(全民k歌怎样禁止陌生人评论)

    全民k歌怎样禁止访问(全民k歌怎样禁止陌生人评论)

  • 苏宁如何解绑银行卡(苏宁如何解绑银行卡绑定)

    苏宁如何解绑银行卡(苏宁如何解绑银行卡绑定)

  • 【数据库】SQL语句(sql数据库语句基本语法)

    【数据库】SQL语句(sql数据库语句基本语法)

  • 基于Vision Transformer的图像去雾算法研究与实现(附源码)(基于专业性的家校双向互动,需要家长的学校教育参与)

    基于Vision Transformer的图像去雾算法研究与实现(附源码)(基于专业性的家校双向互动,需要家长的学校教育参与)

  • 学会这两种方式,我们就可以免费使用chatgpt(学会这两种方式英语)

    学会这两种方式,我们就可以免费使用chatgpt(学会这两种方式英语)

  • 租房代收水电费税率
  • 存货跌价准备计提原则
  • 最新的税收政策
  • 高新企业帐务流程
  • 电力工程公司岗位职责
  • 退回房租含税的情况怎么入账?
  • 外购技术服务费包括哪些
  • 固定资产减半征收2%申报如何填增值税纳税申报表
  • 认缴制下实收资本如何缴纳印花税
  • 税局如何查无票收入
  • 汇算所得税中“以前年度多缴的所得税额在本年抵减额”怎么填 ?
  • 减税必须通过开户银行吗
  • 核定经营额是一个季度还是一个月
  • 存货进项税额转出会计处理
  • 全资子公司向母公司提供劳务服务怎么做账
  • 公司转让税费如何计算
  • 购买法下购买成本包括
  • 1697508432
  • 个人所得税合并申报
  • 行业协会会费收缴标准
  • 如何测试网络延迟
  • 可引导的macos
  • laravel数据迁移
  • PHP:pg_cancel_query()的用法_PostgreSQL函数
  • 股份支付如何缴纳个人所得税?
  • 浅谈特殊儿童的融合教育论文
  • 项目版本管理是什么
  • 在代开发票时已经预缴个人所得税了,怎么处理?
  • 赡养老人专项扣除标准
  • php504错误
  • vue router怎么传值
  • 子公司提取盈余公积 合并抵消
  • vue3使用教程
  • 大学毕业后送快递
  • 个体户税率征收
  • 季末资产总额怎么计算出来的
  • 织梦怎么调用当前栏目下的文章
  • dedecms怎么更换模板
  • 社保工伤退回分录
  • 公司滞纳金员工承担怎么做账
  • 应付账款发生坏账怎么办
  • sqlserver获取数据库名
  • mysql怎样
  • 非独立核算门市部销售自产应税消费品
  • 应付职工薪酬的会计科目
  • 物流公司卖车合法么
  • 库存商品余额在借方是什么意思
  • 不得从销项税额中抵扣的进项税额,不得计提加计抵减额
  • 销售退回 所得税
  • 住宿发票 抵扣
  • 暂估应付账款的科目编码
  • 行政事业单位支出范围和标准
  • 出口退税登记的内容
  • 长期股权投资减值准备是什么意思
  • 物流货到付款可以吗
  • SQL普通表转分区表的方法
  • 自动清理河道垃圾船
  • win8的应用商店
  • 移动u盘的作用
  • DxO Optics Pro 9 激活破解安装详细图文教程
  • win7系统禁止更新
  • wind10怎么打开摄像头
  • win10系统环境设置
  • xp 注册
  • 此电脑右键
  • Win10系统的电脑可装Wlin7系统吗
  • win7开机没反应怎么办
  • windows10周年更新
  • linux配置命令
  • 深入分析的成语
  • div与span区别及用法
  • jquery生成div
  • css中背景图片设置
  • jquery 文本框
  • bootstrap要学到什么程度
  • 深入浅出html pdf中文版
  • 重庆地方税务局12366
  • 电子税务局app扫脸认证
  • 国税地位比地税高吗
  • 电子增值税专用发票和纸质增值税专用发票的区别
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设