位置: IT常识 - 正文

【torch.nn.Parameter 】参数相关的介绍和使用

编辑:rootadmin
【torch.nn.Parameter 】参数相关的介绍和使用 文章目录torch.nn.Parameter基本介绍参数构造参数访问参数初始化使用内置初始化自定义初始化参数绑定参考torch.nn.Parameter基本介绍

推荐整理分享【torch.nn.Parameter 】参数相关的介绍和使用,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

torch.nn.Parameter是继承自torch.Tensor的子类,其主要作用是作为nn.Module中的可训练参数使用。它与torch.Tensor的区别就是nn.Parameter会自动被认为是module的可训练参数,即加入到parameter()这个迭代器中去。

具体格式如下:

torch.nn.parameter.Parameter(data=None, requires_grad=True)

其中 data 为待传入的 Tensor,requires_grad 默认为 True。

事实上,torch.nn 中提供的模块中的参数均是 nn.Parameter 类,例如:

module = nn.Linear(3, 3)type(module.weight)# torch.nn.parameter.Parametertype(module.bias)# torch.nn.parameter.Parameter参数构造

nn.Parameter可以看作是一个类型转换函数,将一个不可训练的类型 Tensor 转换成可以训练的类型 parameter ,并将这个 parameter 绑定到这个module 里面nn.Parameter()添加的参数会被添加到Parameters列表中,会被送入优化器中随训练一起学习更新

此时调用 parameters()方法会显示参数。读者可自行体会以下两端代码:

""" 代码片段一 """class Net(nn.Module): def __init__(self): super().__init__() self.weight = torch.randn(3, 3) self.bias = torch.randn(3) def forward(self, inputs): passnet = Net()print(list(net.parameters()))# []""" 代码片段二 """class Net(nn.Module): def __init__(self): super().__init__() self.weight = **nn.Parameter**(torch.randn(3, 3)) # 将tensor转换成parameter类型 self.bias = **nn.Parameter**(torch.randn(3)) def forward(self, inputs): passnet = Net()print(list(**net.parameters()**)) # 显示参数# [Parameter containing:# tensor([[-0.4584, 0.3815, -0.4522],# [ 2.1236, 0.7928, -0.7095],# [-1.4921, -0.5689, -0.2342]], requires_grad=True), Parameter containing:# tensor([-0.6971, -0.7651, 0.7897], requires_grad=True)]

nn.Parameter相当于把传入的数据包装成一个参数,如果要直接访问/使用其中的数据而非参数本身,可对 nn.Parameter对象调用 data属性:

a = torch.tensor([1, 2, 3]).to(torch.float32)param = nn.Parameter(a)print(param)# Parameter containing:# tensor([1., 2., 3.], requires_grad=True)print(param.data)# tensor([1., 2., 3.])参数访问

nn.Module 中有 **state_dict()** 方法,该方法将以字典形式返回模块的所有状态,包括模块的参数和 persistent buffers ,字典的键就是对应的参数/缓冲区的名称。

【torch.nn.Parameter 】参数相关的介绍和使用

由于所有模块都继承 nn.Module,因此我们可以对任意的模块调用 state_dict() 方法以查看状态:

linear_layer = nn.Linear(2, 2)print(linear_layer.state_dict())# OrderedDict([('weight', tensor([[ 0.2602, -0.2318],# [-0.5192, 0.0130]])), ('bias', tensor([0.5890, 0.2476]))])print(linear_layer.state_dict().keys())# odict_keys(['weight', 'bias'])

对于线性层,除了 state_dict()之外,我们还可以对其直接调用相应的属性,如下:

linear_layer = nn.Linear(2, 1)print(linear_layer.weight)# Parameter containing:# tensor([[-0.1990, 0.3394]], requires_grad=True)print(linear_layer.bias)# Parameter containing:# tensor([0.2697], requires_grad=True)

需要注意的是以上返回的均为参数对象,如需使用其中的数据,可调用 data 属性。

参数初始化使用内置初始化

对于下面的单隐层网络,我们想对其中的两个线性层应用内置初始化器

class Net(nn.Module): def __init__(self): super().__init__() self.layers = nn.Sequential( nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 3), ) def forward(self, X): return self.layers(X)

假设权重从 N(0,1) 中采样,偏置全部初始化为 0,则初始化代码如下:

def init_normal(module): # 需要判断子模块是否为nn.Linear类,因为激活函数没有参数 if type(module) == nn.Linear: nn.init.normal_(module.weight, mean=0, std=1) nn.init.zeros_(module.bias)net = Net()net.apply(init_normal)for param in net.parameters(): print(param)# Parameter containing:# tensor([[-0.3560, 0.8078, -2.4084],# [ 0.1700, -0.3217, -1.3320]], requires_grad=True)# Parameter containing:# tensor([0., 0.], requires_grad=True)# Parameter containing:# tensor([[-0.8025, -1.0695],# [-1.7031, -0.3068],# [-0.3499, 0.4263]], requires_grad=True)# Parameter containing:# tensor([0., 0., 0.], requires_grad=True)

对 net调用 apply方法则会递归地对其下所有的子模块应用 init_normal函数。

自定义初始化

如果我们想要自定义初始化,例如使用以下的分布来初始化网络的权重:

def my_init(module): if type(module) == nn.Linear: nn.init.uniform_(module.weight, -10, 10) mask = module.weight.data.abs() >= 5 module.weight.data *= masknet = Net()net.apply(my_init)for param in net.parameters(): print(param)# Parameter containing:# tensor([[-0.0000, -5.9610, 8.0000],# [-0.0000, -0.0000, 7.6041]], requires_grad=True)# Parameter containing:# tensor([ 0.4058, -0.2891], requires_grad=True)# Parameter containing:# tensor([[ 0.0000, -0.0000],# [-6.9569, -9.5102],# [-9.0270, -0.0000]], requires_grad=True)# Parameter containing:# tensor([ 0.2521, -0.1500, -0.1484], requires_grad=True)参数绑定

对于一个三隐层网络:

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

如果我们想让第二个隐层和第三个隐层共享参数,则可以这样做:

shared = nn.Linear(8, 8)net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), shared, nn.ReLU(), shared, nn.ReLU(), nn.Linear(8, 1))参考

PyTorch学习笔记(六)–Sequential类、参数管理与GPU_Lareges的博客-CSDN博客_sequential类

torch.nn 中文文档

Python的torch.nn.Parameter初始化方法_昊大侠的博客-CSDN博客_torch.nn.parameter初始化

本文链接地址:https://www.jiuchutong.com/zhishi/293777.html 转载请保留说明!

上一篇:如何vue使用ant design Vue中的select组件实现下拉分页加载数据,并解决存在的一个问题。(ant desgin-vue)

下一篇:7.25 web前端-淘宝首页设计(淘宝前端用什么写的)

  • b站app字体大小如何设置(b站app字体大小怎么设置)

    b站app字体大小如何设置(b站app字体大小怎么设置)

  • 华为nova5z可以快充吗(nova5可以快充吗)

    华为nova5z可以快充吗(nova5可以快充吗)

  • 微信怎么查看朋友圈点赞和回复记录(微信怎么查看朋友圈屏蔽的人)

    微信怎么查看朋友圈点赞和回复记录(微信怎么查看朋友圈屏蔽的人)

  • P40与P30的区别(p30和p40有什么区别)

    P40与P30的区别(p30和p40有什么区别)

  • 苹果x和11区别(苹果x和11一样吗)

    苹果x和11区别(苹果x和11一样吗)

  • 苹果笔记本无线网开关在哪(苹果笔记本无线投屏)

    苹果笔记本无线网开关在哪(苹果笔记本无线投屏)

  • 苹果wapi打开还是关闭(苹果手机wapi打开好还是不打开)

    苹果wapi打开还是关闭(苹果手机wapi打开好还是不打开)

  • 网关错误是什么意思(网关错误是什么原因)

    网关错误是什么意思(网关错误是什么原因)

  • 微信几年前的聊天记录可以找到吗(微信几年前的聊天记录有办法恢复么)

    微信几年前的聊天记录可以找到吗(微信几年前的聊天记录有办法恢复么)

  • 为什么小米手环4微信信息来没有提示(为什么小米手环表盘无法同步)

    为什么小米手环4微信信息来没有提示(为什么小米手环表盘无法同步)

  • 充电仓需要充电几小时(充电仓需要充电多长时间)

    充电仓需要充电几小时(充电仓需要充电多长时间)

  • 美团快驴是做什么的(美团快驴工作怎么样?)

    美团快驴是做什么的(美团快驴工作怎么样?)

  • 网线哪两根是正负极(网线哪两根是正极线)

    网线哪两根是正负极(网线哪两根是正极线)

  • 5sa1533是什么版本(iphone5sa1530是什么版本)

    5sa1533是什么版本(iphone5sa1530是什么版本)

  • 微信注销多久申请新号(微信注销要多久才可以重新申请)

    微信注销多久申请新号(微信注销要多久才可以重新申请)

  • 宽带红灯闪烁怎么回事(宽带红灯闪烁是什么意思)

    宽带红灯闪烁怎么回事(宽带红灯闪烁是什么意思)

  • 如何下载视频(如何下载视频素材)

    如何下载视频(如何下载视频素材)

  • 手机停机怎么恢复正常(手机停机怎么恢复移动网络)

    手机停机怎么恢复正常(手机停机怎么恢复移动网络)

  • 京东实名申诉要多久(京东被实名注册怎么办)

    京东实名申诉要多久(京东被实名注册怎么办)

  • 为什么浏览器打不开(为什么浏览器打开是百度网页)

    为什么浏览器打不开(为什么浏览器打开是百度网页)

  • 手机怎么知道照片多少k(手机怎么知道照片是多少K的)

    手机怎么知道照片多少k(手机怎么知道照片是多少K的)

  • word如何删除占位符(清除word)

    word如何删除占位符(清除word)

  • vivox20充电器多少瓦(vivox20a原装充电器多少钱)

    vivox20充电器多少瓦(vivox20a原装充电器多少钱)

  • 不用iTunes如何将iPhone和iPad同步到Mac?新版macOS Catalina升级方法汇总(不用itunes怎么下载软件)

    不用iTunes如何将iPhone和iPad同步到Mac?新版macOS Catalina升级方法汇总(不用itunes怎么下载软件)

  • 台式电脑组装过程详细图解(台式电脑组装过程视频)

    台式电脑组装过程详细图解(台式电脑组装过程视频)

  • phpcms与phpsso通信失败的解决方法(phpcms使用教程)

    phpcms与phpsso通信失败的解决方法(phpcms使用教程)

  • 投资性房地产出售时公允价值变动损益
  • 资产负债表的固定资产怎么计算
  • 筹建期印花税退税分录
  • 机器配件属于什么报销项目
  • 企业销售软件需要结转成本吗
  • 核定征收的收入总额包括营业外收入吗
  • 本月进项税大于销项税有留底,如何做会计分录
  • 销售之后发生销货折让收到红字发票如何做账?
  • 进口代收业务
  • 外购无形资产的相关税费包括增值税吗
  • 行政事业单位核算短期投资时有关预算会计核算正确的是
  • 委托加工农产品的扣除率
  • 递延收益在资产负债表哪里列示
  • 企业核税需要什么资料
  • 契税和车辆购置税的异同
  • 3%税率是一般纳税人还是小规模
  • 企业所得税必须要季度缴纳吗
  • 小规模纳税人转成一般纳税人条件
  • 建筑行业预算
  • 税控技术服务费计入什么科目
  • 增值税专票盖章盖在哪里
  • 企业的商誉会一直存在吗
  • 累计折旧余额怎样结转
  • 供应商赠送的原材料怎么做分录
  • 个税是什么意思必须交吗
  • 微pe工具箱怎么用
  • PHP:xml_set_object()的用法_XML解析器函数
  • fs209e是什么意思
  • PHP:imagepsloadfont()的用法_GD库图像处理函数
  • 银行借款利息支出计入什么科目
  • php如何实现
  • 在产品按定额成本计价法的特点
  • 阿查法拉亚盆地牡蛎
  • 应交税费会计分录例题
  • 红字发票步骤
  • php怎么写接口给别人调用
  • react的高阶组件理解
  • django pypi
  • linux 常用命令大全及其详解
  • springboot整合websocket怎么接受图片消息
  • php源码抓取工具
  • 分公司需要做纳税申报吗
  • 善意取得增值税专用发票
  • CentOS6.9下mysql 5.7.17安装配置方法图文教程
  • Apache RocketMQ 5.0 笔记
  • 房地产开发企业成本核算方法
  • 常用sql脚本
  • sqlserver2019删除
  • 固定资产加速折旧的方法有哪些
  • 个人所得税手续费返还账务处理
  • 发票冲红重开,重开时是按新税率还是旧税率?
  • 差旅费会计科目怎么做
  • 本年利润的会计分录怎么写
  • 融资租赁汽车怎么投诉电话
  • 计提工资的会计处理
  • 暂扣员工工资应怎么处理
  • 固定资产售后回购
  • 公司缴纳印花税如何缴纳
  • 政府预算年度
  • 公司发放给员工的福利又要回
  • 残疾人就业保障金怎么计算
  • 核定征收的小微企业
  • mysql数据库视频
  • windows7 设置
  • Mac上Parallels Desktop共享虚拟机怎么设置 Mac上Parallels Desktop共享虚拟机设置步
  • 新款苹果笔记本测评
  • win8计算机管理员权限
  • linux删除lun
  • cocos2d-x教程
  • 每日一个linux命令
  • vue.js有什么用
  • javascript教程chm
  • flash怎么测试当前场景
  • javascript基础笔记
  • 网上怎么交车船税
  • 运输费用抵扣税率最新规定
  • 吉林省耕地占用税实施办法
  • 客货两用车应如何运输
  • 动物大联盟是国内品牌吗
  • 重庆轨道第五轮19号线路
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设