位置: IT常识 - 正文

pytorch对网络层的增,删, 改, 修改预训练模型结构(pytorch自定义网络层)

编辑:rootadmin
pytorch对网络层的增,删, 改, 修改预训练模型结构 #下载模型参数model.load_state_dict(torch.load('model.pth'))#再加载网络的参数torch.load('model.pth')是获得网络参数

推荐整理分享pytorch对网络层的增,删, 改, 修改预训练模型结构(pytorch自定义网络层),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch怎么训练网络,pytorch 网络,pytorch网络搭建,pytorch定义网络,pytorch cnn网络,pytorch bp网络,pytorch输出网络结构,pytorch cnn网络,内容如对您有帮助,希望把文章链接给更多的朋友!

1.我们使用vgg11网络做示例, 看一下网络结构:

加载本地的模型:

vgg16 = models.vgg16(pretrained=False)#打印出预训练模型的参数vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))

加载库中的模型

import torchimport torch.nn as nnfrom torchvision import modelsnet = models.vgg11(pretrained=True)print(net)

1)(1). 在网络中添加一层:

net网络是一个树型结构, net下面有三个结点,分别是(features, avgpoll, classifier), 我们先在features结点添加一层’lastlayer'层

net.features.add_module('lastlayer', nn.Conv2d(512,512, kernel_size=3, stride=1, padding=1))print(net) 1)(2). 在classifier结点添加一个线性层:net.classifier.add_module('Linear', nn.Linear(1000, 10))print(net)

2)(1)修改网络中的某一层(features 结点举例):net.features[8] = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))print(net)

 2)(2)修改网络中的某一层(classifier结点举例):net.classifier[6] = nn.Linear(1000, 5)print(net)pytorch对网络层的增,删, 改, 修改预训练模型结构(pytorch自定义网络层)

注意: 这里我尝试对Linear这一层进行更新, 但是Linear名字是字符串, 提取不出来,所以应该在之前添加网络时候, 名字不要取字符串, 否则会报错  ‘  'str' object cannot be interpreted as an integer’。 

 3)(1) 网络层的删除(features举例) classifier结点的操作相同。

直接使用nn.Sequential()对改层设置为空即可

net.features[13] = nn.Sequential()print(net)

 4)冻结网络中某些层 (直接使该层的requires_grad = False)即可, 这样在反向传播的时候,不会更新该层的参数#冻结指定层的预训练参数:net.feature[26].weight.requires_grad = False5). 第二种对网络结构的操作方法:net.features = nn.Sequential(*list(net.features.children())[:-4])

可以看到后面4层被去除了, 就是说可以使用列表切片的方法来删除网络层

net.classifier 对应 net.classifier.children()

net.features 对应 net.features.children()

  1. 先加载网络结构

自己的模型, model的类要有定义才可以, 如果在其他.py文件中,可以导入文件,然后用文件中的类实例化对象。model = torch.load(PATH)

 2.再加载网络参数

#下载模型参数

model.load_state_dict(torch.load('model.pth'))#再加载网络的参数torch.load('model.pth')是获得网络参数
本文链接地址:https://www.jiuchutong.com/zhishi/299377.html 转载请保留说明!

上一篇:vue 项目兼容 IE 浏览器(vue项目兼容ie9以上浏览器)

下一篇:【CSS】CSS 背景设置 ⑨ ( 背景半透明设置 )(css背景图)

  • 建筑业固定资产折旧费用科目是什么
  • 耕地占用税的税目
  • acca b/f
  • 怎么用一证通报税
  • 应发工资和实发工资计算公式excel
  • 事业单位合并财务交接
  • 研发费用长期是否可控
  • 总分类一般采用什么格式
  • 营改增后小规模都是三个点吗
  • 房租费简易征收税率
  • 建筑工程预收款开票会计分录
  • 借主营业务成本贷应付账款
  • 申报缴纳印花税,取得银行缴税凭证
  • 公司的房产税如何征收
  • 自建厂房转固定资产如何办理手续
  • win7改win10详细教程
  • 项目投资净现值计算公式
  • 运输业什么进项税抵扣
  • 购买产品优惠计入什么科目
  • linux的基础知识
  • 若依框架是谁写的
  • 按销售收入比例分摊进项税额按月还是按年
  • 一般纳税人做账报税的整个流程详细
  • 增值税发票开红字发票后账务处理
  • 外贸企业有哪些公司青岛
  • 逾期未收回包装物押金税率
  • 短期借贷属于负债类科目
  • 补提去年的所得税费用是怎么做分录?
  • 微信小程序项目开发实战
  • 实际交印花税会计分录
  • 海关进口税可以抵扣吗
  • php日期时间函数
  • php中的pdo
  • 工作服列支什么科目
  • k8s kubelet
  • php单例模式
  • 残保金交错了怎么办
  • 发票勾选平台进入后没有什么内容
  • phpcms怎么修改模板风格
  • 结转销售成本的凭证需要附件吗
  • 购房发票可以对折吗
  • 报销人和收款人不一致钱打到哪个账户
  • 外购存货的初始成本由买价加采购费用构成
  • 个人向企业借贷违法吗
  • 临时工资怎么核算
  • 营运资金周转率是什么指标
  • 广告费的税额计入哪里
  • 售后回租如何做会计处理
  • 增值税扣税公式
  • 小规模纳税人涉嫌虚开发票
  • 如何理解出口退税的意义
  • 拆迁房视同销售成本可以抵减吗?
  • 专项费用会计分录
  • 小微企业取得的进项税能不能抵扣
  • 销售固定资产怎么做账务处理
  • 新成立的公司需要年检吗
  • unix系统采用什么结构
  • windows vista可以换7吗
  • window系统怎么取消开机密码
  • Win10 Mobile RS2预览版14915上手视频评测
  • mac 鼠标调整
  • linux运维常用命令汇总
  • nginx文件服务器
  • win10系统开机后任务栏无响应怎么解决
  • jquery实战
  • 怎么配置nodejs
  • cocos2djs
  • jquery教程w3c
  • div显示边框线
  • 简单的安卓代码
  • unity+
  • 超详细的卡拉赞攻略
  • android app安全
  • jQuery+Ajax+PHP弹出层异步登录效果(附源码下载)
  • 进项转出怎么做
  • 电子签章在电脑上怎样加印章
  • 出口退税外汇汇率如何确定
  • 010是哪个市区的电话号码
  • 车船税为什么有时候不用交
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设