位置: IT常识 - 正文

pytorch对已有模型的更改(常用的操作)(pytorch model.module)

编辑:rootadmin
pytorch对已有模型的更改(常用的操作)

推荐整理分享pytorch对已有模型的更改(常用的操作)(pytorch model.module),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch有哪些模型,pytorch model.predict,pytorch_model.bin,pytorch model.predict,pytorch_model.bin,pytorch modules,pytorch model.parameters,pytorch model.module,内容如对您有帮助,希望把文章链接给更多的朋友!

本文会做经常性的更改,如有错误或者其他补充的,请各位大佬不吝指点。

如图所示为我们的示例输出的网络结构。

引入创建的模型:

import torchimport simple_modulemod = simple_module.Module()

一、模型的保存与读取

1.整个模型的保存与读取

# 保存整个模型torch.save(mod, '../parameters/mod.pth')# 读取整个模型mod_load = torch.load('../parameters/mod.pth')

2.模型参数的保存与读取(以字典方式保存和读取)

# # 保存模型的参数(以字典的方式保存)torch.save(mod.state_dict(), '../parameters/mod_parameter.pth')# 查看保存了哪些参数print(mod.state_dict().keys())print(mod.state_dict()['feature.0.0.bias'])# 读取模型的参数(以字典的方式读取)mod.load_state_dict(torch.load('../parameters/mod_parameter.pth'))odict_keys(['feature.0.0.weight', 'feature.0.0.bias', 'feature.0.1.weight','feature.0.1.bias', 'feature.0.1.running_mean', 'feature.0.1.running_var', 'feature.0.1.num_batches_tracked', 'feature.1.0.weight', 'feature.1.0.bias', 'feature.1.1.weight', 'feature.1.1.bias', 'feature.1.1.running_mean', 'feature.1.1.running_var', 'feature.1.1.num_batches_tracked', 'classifier.1.weight','classifier.1.bias'])tensor([-0.1721, -0.1222, 0.1023, -0.1484, -0.0547, -0.1922, -0.0796, -0.1784, -0.0233, -0.0271, -0.1018, 0.1875])pytorch对已有模型的更改(常用的操作)(pytorch model.module)

二、模型更改某一层

# 模型修改某一层mod.classifier[1] = torch.nn.Linear(in_features=3072, out_features=20, bias=True)

三、模型删除某些层

# 删除某一层,可以将该层设置为空序列mod.classifier[1] = torch.nn.Sequential()# 可以采用切片的方式删除,这样删除更加彻底mod.classifier = torch.nn.Sequential(*list(mod.classifier.children())[:-1])# 或者直接删除mod.classifier.__delattr__('1')

四、模型添加层(貌似只能在某一个块的末尾添加,后续再查找资料,有大佬可以指点一下)

# 模型添加层mod.classifier.add_module(name='liner', module=torch.nn.Linear(in_features=3072, out_features=100, bias=True))

五、冻结某些层,使得训练时不进行参数更行

1.冻结某一层

# 冻结某一层mod.feature[0][0].weight.requires_grad = False

2.冻结所有的参数

# 冻结所有的参数for param in mod.parameters(): param.requires_grad = False

3.冻结前面某部分的参数,可先将参数名称罗列出来,然后选择一部分的参数名称,利用参数的名称进行冻结。这种方式可以任意地冻结自己想要冻结的层。

no_grad = []for name, value in mod.named_parameters(): # print(name) no_grad.append(name)no_grad = no_grad[:-4]for name, value in mod.named_parameters(): if name in no_grad: value.requires_grad = False else: value.requires_grad = True

 4.还有一种方式,就是只冻结前面几层

i = 0for name, value in mod.named_parameters(): value.requires_grad = False i = i + 1 if i == 4: break;

或者

model_parameters = model.named_parameters()for i in range(freeze): name, value = next(model_parameters) value.requires_grad = False

这是我目前想到的一个方法,还有其他方法的请大佬不吝指点。 

无论哪种方式,都是将对应层的weight的requires_grad设置为False。

5.最后还需要给优化器设置过滤器

# 定义一个fliter,只传入requires_grad=True的模型参数optimizer = optim.SGD(filter(lambda p : p.requires_grad, mod.parameters()), lr=1e-2)
本文链接地址:https://www.jiuchutong.com/zhishi/288943.html 转载请保留说明!

上一篇:vue表单验证rules以及validator验证器的使用(vue表单验证数字)

下一篇:厄勒布鲁附近湖上的仲夏之光,瑞典 (© Anders Jorulf/Getty Images)(厄勒布鲁赛程500)

  • 成都培训公司有哪些效果好推荐

    成都培训公司有哪些效果好推荐

  • Apple Developer 应用程序在 WWDC 2022 之前更新了新功能和错误修复

    Apple Developer 应用程序在 WWDC 2022 之前更新了新功能和错误修复

  • 路由器设置不正确(路由器设置不能上网)(路由器设置不正常怎么办)

    路由器设置不正确(路由器设置不能上网)(路由器设置不正常怎么办)

  • 目录怎么自己编辑(目录这么编写)

    目录怎么自己编辑(目录这么编写)

  • 得物直接拒收可以退款吗(得物拒收普通快递)

    得物直接拒收可以退款吗(得物拒收普通快递)

  • 微信收款语音提醒怎么开启(微信收款语音提示怎么关闭)

    微信收款语音提醒怎么开启(微信收款语音提示怎么关闭)

  • 支付宝怎么充抖音币(支付宝怎么充抖音抖币)

    支付宝怎么充抖音币(支付宝怎么充抖音抖币)

  • 微信来源看不见是怎么回事(微信来源看不到)

    微信来源看不见是怎么回事(微信来源看不到)

  • 为什么加微信号显示用户不存在(为什么加微信号显示不存在)

    为什么加微信号显示用户不存在(为什么加微信号显示不存在)

  • oppo手机自动静音怎么回事(oppo手机自动静音)

    oppo手机自动静音怎么回事(oppo手机自动静音)

  • 小红书直播怎么点赞(小红书直播怎么挂链接)

    小红书直播怎么点赞(小红书直播怎么挂链接)

  • vivoz5和vivoz5i有什么区别(vivoz5i和哪个手机型号一样)

    vivoz5和vivoz5i有什么区别(vivoz5i和哪个手机型号一样)

  • 微视号可以更改吗(微视号能改吗?)

    微视号可以更改吗(微视号能改吗?)

  • i51035g1性能(i51035g1性能相当于台式机什么)

    i51035g1性能(i51035g1性能相当于台式机什么)

  • 华为镜头膜有必要贴吗(华为镜头膜有必要贴膜吗)

    华为镜头膜有必要贴吗(华为镜头膜有必要贴膜吗)

  • 什么软件录屏可以录内部声音(什么软件录屏可以把声音录进去)

    什么软件录屏可以录内部声音(什么软件录屏可以把声音录进去)

  • 在excel中,单元格地址绝对引用的方法是(在excel中,单元格a8的绝对引用应写为( ))

    在excel中,单元格地址绝对引用的方法是(在excel中,单元格a8的绝对引用应写为( ))

  • 打印机一定要连接电脑才能打印吗(打印机一定要连电脑才能用吗)

    打印机一定要连接电脑才能打印吗(打印机一定要连电脑才能用吗)

  • mute是什么按键(mu是什么键位)

    mute是什么按键(mu是什么键位)

  • 拍抖音怎么把抖音号隐藏(拍抖音怎么把抖音二字弄掉)

    拍抖音怎么把抖音号隐藏(拍抖音怎么把抖音二字弄掉)

  • 微博聊天记录怎么恢复(微博聊天记录怎么导出)

    微博聊天记录怎么恢复(微博聊天记录怎么导出)

  • 幻灯片设计模板在哪里(幻灯片设计模板柏林)

    幻灯片设计模板在哪里(幻灯片设计模板柏林)

  • 手机qq怎么冻结账号(手机qq如何冻结)

    手机qq怎么冻结账号(手机qq如何冻结)

  • 快剪辑手机版怎么使用(快剪辑手机版怎么剪辑)

    快剪辑手机版怎么使用(快剪辑手机版怎么剪辑)

  • 苹果a1687是什么型号(苹果a1687是哪里版本)

    苹果a1687是什么型号(苹果a1687是哪里版本)

  • 魅族16如何截屏(魅族16如何截屏图片)

    魅族16如何截屏(魅族16如何截屏图片)

  • 打开网页时显示已重置连接(打开网页时显示无法获取属性)

    打开网页时显示已重置连接(打开网页时显示无法获取属性)

  • Vue中的数据操作(vue数据表)

    Vue中的数据操作(vue数据表)

  • ntpdc命令  查询NTP守护进程(查看ntpdate状态)

    ntpdc命令 查询NTP守护进程(查看ntpdate状态)

  • 个人所得税申报退税多久到账
  • 离职员工个税申报时员工状态依然是雇员
  • 销项负数的分录怎么做
  • 增值税交错了退税怎么退
  • 一个季度30万是不含税吗
  • 无票收入应该怎么做
  • 建筑行业分包款要分项目扣除吗
  • 一般纳税人增值税申报操作流程
  • 当买方违约时,卖方可以得到哪些补救?
  • 支付税点怎么做账
  • 企业长期待摊费用包括
  • 小规模纳税人减免税明细表怎么填
  • 企业的不征税收入用于支出所形成的资产,其计算的折旧
  • 分期收款什么时候交税
  • 一般纳税人有进项无销项
  • 代驾费用入什么二级科目
  • 简易征收发票能抵扣吗
  • 无形资产增值税计入入账价值吗
  • 免征的增值税怎么做账
  • 增值税附加怎么入账
  • 销售商品尚未发出会计分录
  • 原材料结转成本的会计分录例题
  • 以公允价值计量是什么意思
  • cefres.dll是什么
  • php中获取当前时间
  • 一键部署源码
  • 使用灭火器时要对准火焰的什么部位喷射
  • PHP中spl_autoload_register()函数用法实例详解
  • bad block bitmap checksum
  • python tkinter详解
  • 金税盘中的发票修复是什么意思
  • 会务费什么企业可以开
  • 小企业汇兑损失
  • 日记账对方科目是什么意思
  • 不具备独立核算条件的行政单位
  • static在c语言中用法
  • 帝国cms怎么调用文章随机段落
  • 公司收到个人汇款怎么开发票
  • 劳务公司属于什么
  • 企业开办前需要预测现金流量计划吗
  • 土地增值税要计入税金及附加吗
  • 记账凭证核算形式
  • 可抵扣的进项税额要减去进项转出吗
  • 执行迟延履行利息的计算
  • 低值易耗品属于周转材料吗
  • 私营企业员工享受探亲假吗
  • 财务费用利息收入的账务处理
  • 一般纳税人普通发票要交增值税吗
  • 待结算财政款项是什么科目
  • 结转存货跌价准备是什么意思
  • 限售股上市流通是好还是坏
  • 小加工厂怎么开发票
  • 财务费用包括哪几项
  • 软件开发是否属于采购目录
  • mysql8高可用
  • MySQL数据库开发技术电子版
  • mysql开发语言
  • sqlserver 通用分页存储过程
  • mysql密码忘了怎么办?
  • apple watch手表怎么看型号
  • ububtu安装教程
  • ubuntu设置登录用户
  • win10录音机不能用
  • linux内核2.3.20
  • win7重装系统需要重新激活吗
  • win10 windows更新清理删不掉
  • win7系统通过wmic命令
  • 对于javascript理解
  • node.js抓包
  • jquery时钟插件
  • tomcat8.5.8
  • node. js教程
  • unity as
  • Android调用jni获取mac地址
  • 广东国税局发票查询平台
  • 完税证明可以自己在官网打印吗
  • ca用户绑定怎么绑
  • 企业年度申报怎么修改
  • 国税软件下载
  • 销售税金含增值税怎么计算企业所得税
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设