位置: IT常识 - 正文

pytorch对网络层的增,删, 改, 修改预训练模型结构(pytorch自定义网络层)

编辑:rootadmin
pytorch对网络层的增,删, 改, 修改预训练模型结构 #下载模型参数model.load_state_dict(torch.load('model.pth'))#再加载网络的参数torch.load('model.pth')是获得网络参数

推荐整理分享pytorch对网络层的增,删, 改, 修改预训练模型结构(pytorch自定义网络层),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch怎么训练网络,pytorch 网络,pytorch网络搭建,pytorch定义网络,pytorch cnn网络,pytorch bp网络,pytorch输出网络结构,pytorch cnn网络,内容如对您有帮助,希望把文章链接给更多的朋友!

1.我们使用vgg11网络做示例, 看一下网络结构:

加载本地的模型:

vgg16 = models.vgg16(pretrained=False)#打印出预训练模型的参数vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))

加载库中的模型

import torchimport torch.nn as nnfrom torchvision import modelsnet = models.vgg11(pretrained=True)print(net)

1)(1). 在网络中添加一层:

net网络是一个树型结构, net下面有三个结点,分别是(features, avgpoll, classifier), 我们先在features结点添加一层’lastlayer'层

net.features.add_module('lastlayer', nn.Conv2d(512,512, kernel_size=3, stride=1, padding=1))print(net) 1)(2). 在classifier结点添加一个线性层:net.classifier.add_module('Linear', nn.Linear(1000, 10))print(net)

2)(1)修改网络中的某一层(features 结点举例):net.features[8] = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))print(net)

 2)(2)修改网络中的某一层(classifier结点举例):net.classifier[6] = nn.Linear(1000, 5)print(net)pytorch对网络层的增,删, 改, 修改预训练模型结构(pytorch自定义网络层)

注意: 这里我尝试对Linear这一层进行更新, 但是Linear名字是字符串, 提取不出来,所以应该在之前添加网络时候, 名字不要取字符串, 否则会报错  ‘  'str' object cannot be interpreted as an integer’。 

 3)(1) 网络层的删除(features举例) classifier结点的操作相同。

直接使用nn.Sequential()对改层设置为空即可

net.features[13] = nn.Sequential()print(net)

 4)冻结网络中某些层 (直接使该层的requires_grad = False)即可, 这样在反向传播的时候,不会更新该层的参数#冻结指定层的预训练参数:net.feature[26].weight.requires_grad = False5). 第二种对网络结构的操作方法:net.features = nn.Sequential(*list(net.features.children())[:-4])

可以看到后面4层被去除了, 就是说可以使用列表切片的方法来删除网络层

net.classifier 对应 net.classifier.children()

net.features 对应 net.features.children()

  1. 先加载网络结构

自己的模型, model的类要有定义才可以, 如果在其他.py文件中,可以导入文件,然后用文件中的类实例化对象。model = torch.load(PATH)

 2.再加载网络参数

#下载模型参数

model.load_state_dict(torch.load('model.pth'))#再加载网络的参数torch.load('model.pth')是获得网络参数
本文链接地址:https://www.jiuchutong.com/zhishi/299377.html 转载请保留说明!

上一篇:vue 项目兼容 IE 浏览器(vue项目兼容ie9以上浏览器)

下一篇:【CSS】CSS 背景设置 ⑨ ( 背景半透明设置 )(css背景图)

  • 个税退还会计分录
  • 个人所得税谁交
  • 虚开普票的立案标准
  • 年报现金流量表可以不填吗
  • 母子公司间借款利息交税吗
  • 企业内部转账应注意什么
  • 售后维修的配件开维修费
  • 贷款减值准备如何计提
  • 大病医疗专项附加扣除标准举例
  • 公司从个人手中租房不能取得发票
  • 一般纳税人农业增值税如何申报
  • 股东投入的资金要交税吗
  • 开完红字发票后,正数发票如何开具?
  • 增值税专用发票抵扣期限
  • 跨季度的发票怎么冲销
  • 职工个人负担的医疗保险可以在计算个人所得税前扣除
  • 代收代缴水电费能开发票吗
  • 公司与股东的往来款现金流量表
  • 分期收款销售的商品属于存货吗
  • 延期收款利息如何算
  • 如何解决win7系统搜不到蓝牙耳机
  • 苹果笔记本如何切换输入法
  • 债务优化是做什么的工作
  • 企业所得税申报流程
  • 被税务查了
  • 印花税的征收范围
  • 企业收缩案例
  • 代理公司可以开服务费发票吗
  • dhcp存在哪些安全隐患
  • 进口报关费用会计分录
  • ubuntu 20.04.1
  • php字符串赋值
  • 依夫城堡
  • 微信小程序前端源码
  • 季度缴纳企业所得税计算方法
  • 研发和技术服务税率3%
  • 联邦学习(FL)+差分隐私(DP)
  • 资产总额怎么计算公式
  • elementui表格自定义排序
  • 公司的一项专利多少钱
  • 什么情况下不用割包皮
  • 资产负债表中应交税费为负数是什么意思
  • RedHat6.5/CentOS6.5安装Mysql5.7.20的教程详解
  • 小企业一年需要缴纳多少税
  • 《开具红字增值税专用发票通知单》
  • 开具红字增值税普通发票
  • 存款利息收入一般是多少
  • 应交税费在会计科目的借贷方向
  • 跨年发票两大原则
  • 小规模纳税人缴纳的增值税计入成本吗
  • 三方协议代付的钱在哪里
  • 承兑汇票收据开什么发票
  • 应收账款的明细科目一般按照什么设置
  • 房地产企业扣除土地价款如何申报
  • 出口货物不免不退
  • 工会发放节日慰问品种类
  • 筹建期费用需要结转吗
  • 信用代码证过期了6年怎么办理
  • Fedora 9.0 Apache+PHP+MYSQL 环境安装
  • centos设置hostname
  • macbook怎么开启
  • 2021年win10累积更新
  • msg是什么文件
  • win7如何限制网速
  • Win10控制面板打不开
  • 何为黄金茶
  • unity cpu优化
  • linux中mysql备份shell脚本代码
  • rpg游戏脚本已经被备份
  • jquery 延迟对象
  • 短信发送器
  • js原型作用
  • unity官方插件
  • android设计模式书籍
  • pycharm flask框架
  • 浙里办怎么给小孩子缴医保费
  • 内蒙古电子税务局app官方下载
  • 成立税务师事务所一定要执业会员吗
  • 广东地税服务电话
  • 发票清单盖章样本图片
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设