位置: IT常识 - 正文

Pytorch深度学习实战3-6:详解网络骨架模块nn.Module(附实例)

编辑:rootadmin
原力计划Pytorch深度学习实战3-6:详解网络骨架模块nn.Module(附实例) 目录1 什么是nn.Module?2 从一个例子说起3 nn.Module主要方法4 自定义网络一般步骤1 什么是nn.Module?

推荐整理分享Pytorch深度学习实战3-6:详解网络骨架模块nn.Module(附实例),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

在实际应用过程中,经典网络结构(如卷积神经网络)往往不能满足我们的需求,因而大多数时候都需要自定义模型,比如:多输入多输出(MIMO)、多分支模型、跨层连接模型等。nn.Module就是Pytorch中用于自定义模型的核心方法。在Pytorch中,自定义层、自定义块、自定义模型,都是通过继承nn.Module类完成的。

nn.Module的定义如下

class Module(object): def __init__(self): def forward(self, *input): def __call__(self, *input, **kwargs): def parameters(self, recurse=True): def named_parameters(self, prefix='', recurse=True): def children(self): def named_children(self): def modules(self): def named_modules(self, memo=None, prefix=''): def train(self, mode=True): def eval(self): def zero_grad(self):...

注意:自定义网络需要继承nn.Module类,并重点实现上面的构造函数__init__构造函数和forward()这两个方法。

2 从一个例子说起

下面是一个自定义感知机的实例

# 感知机class Perception(nn.Module): def __init__(self, inDim, hidDim, outDim): super(Perception, self).__init__() self.perception = nn.Sequential( nn.Linear(inDim, hidDim), nn.Sigmoid(), nn.Linear(hidDim, outDim), nn.Sigmoid() ) def forward(self, x): return self.perception(x)

测试模块

perception = Perception(5,20,10)print(perception(torch.Tensor([1,2,3,4,5]))) # 自动调用forward()前向传播

其中nn.Sequential()可以序列化封装若干个相连的组件,在希望快速搭建模型且无需考虑中间过程的情形下,推荐使用nn.Sequential()进行局部模块化。

Pytorch深度学习实战3-6:详解网络骨架模块nn.Module(附实例)

从上面的实例可以看出:

一般把网络中的特定结构(如全连接层、卷积层等)以序列的形式放在构造函数__init__()中将模型自定义的各个层的连接关系和数据通路设计放在forward()函数中,以实现模型功能并保证数据结构正常不具有可学习参数的层(如ReLU、dropout、BatchNormanation层等)可并入__init__()内部的某个层,或在forward()函数中进行层间连接

库nn.functional同样提供了大量网络模块和组件,与nn.Module类不同在于其更偏向底层——nn.Module封装了对学习参数的维护,更注重模型结构;nn.functional需要手动指定参数和结构,例如下面线性模型Linear的核心源码,其前向过程仍然调用了底层的nn.functional实现。

class Linear(Module): def __init__(self, in_features: int, out_features: int) -> None: super(Linear, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(out_features, in_features)) self.bias = Parameter(torch.Tensor(out_features)) def forward(self, input: Tensor) -> Tensor: return F.linear(input, self.weight, self.bias)

一般在设计通过已有nn.Module无法组装的网络结构时,可以调用底层的nn.functional实现;或是存在无需优化学习参数的结构(如损失函数、激活函数等),可以调用nn.functional(即作为单纯函数使用)避免实例化nn.Module,轻量化网络

# 使用nn.Module需要实例化后调用lossFunc = nn.CrossEntropyLoss()loss = lossFunc(output, label)# 使用nn.functional则只作为函数即可loss = F.cross_entropy(output, label)3 nn.Module主要方法

nn.Module的主要属性与方法列举如表所示。

序号属性/方法含义1forward()模型前向传播2train()训练模式3eval()评估模式4named_parameters()返回模型各可学习参数的名称和参数组成的列表5parameters()返回模型各可学习参数组成的列表6children()返回一个迭代器,其中每个元素是Sequential序列类型,可以使用下标索引来进一步获取每一个Sequenrial里面的具体层,比如conv层、dense层等7named_children()返回一个迭代器,其中每个元素是一个二元组,第一元是名称,第二元是该名称对应的层或Sequential序列4 自定义网络一般步骤

自定义网络一般步骤总结如下:

自定义一个继承自Module的类实现构造函数_init__,在其中参数化网络层,比如卷积神经网络的卷积核大小、池化层尺寸,全连接网络的输入输出大小等;实现前向传播forward()接口,定义网络的连接情况或其他运算方式(如向量拼接、向量变维、数据处理等)

下面再给出一个卷积神经网络的实例加深理解

class CNN(nn.Module): def __init__(self): super().__init__() self.convPoolLayer_1 = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5), nn.MaxPool2d(kernel_size=2), nn.ReLU() ) self.convPoolLayer_2 = nn.Sequential( nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5), nn.MaxPool2d(kernel_size=2), nn.ReLU() ) self.fcLayer = nn.Linear(320, 10) def __str__(self) -> str: return "cnn_model" def forward(self, x): batchSize = x.size(0) x = self.convPoolLayer_1(x) x = self.convPoolLayer_2(x) x = x.reshape(batchSize, -1) x = self.fcLayer(x) return x


🔥 更多精彩专栏:

《ROS从入门到精通》《Pytorch深度学习实战》《机器学习强基计划》《运动规划实战精讲》…

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

本文链接地址:https://www.jiuchutong.com/zhishi/297615.html 转载请保留说明!

上一篇:【Vant Weapp】van-tab 标签页(vant weapp官方文档)

下一篇:Django--基于Python的Web应用框架(django pycharm)

  • 2021淘宝天猫双十一活动什么时候开始(2021淘宝天猫双十一销售额)

    2021淘宝天猫双十一活动什么时候开始(2021淘宝天猫双十一销售额)

  • mysql 查看版本(mysql如何查看版本)

    mysql 查看版本(mysql如何查看版本)

  • word怎么把图片居中(word怎么把图片放在中间)

    word怎么把图片居中(word怎么把图片放在中间)

  • 抖音极速版的购物车在哪里找出来(抖音极速版的购物通知怎么删除)

    抖音极速版的购物车在哪里找出来(抖音极速版的购物通知怎么删除)

  • wps怎么设置自动保存(wps怎么设置自动播放)

    wps怎么设置自动保存(wps怎么设置自动播放)

  • 计算机按其性能分为哪五大类(计算机按其性能规模速度和功能等可分为什么)

    计算机按其性能分为哪五大类(计算机按其性能规模速度和功能等可分为什么)

  • 1的二进制是多少(1的二进制是什么)

    1的二进制是多少(1的二进制是什么)

  • 苹果8辅助触控自动消失(苹果8辅助触控老是自己跳出来)

    苹果8辅助触控自动消失(苹果8辅助触控老是自己跳出来)

  • qq怎么换聊天位置左右(如何在qq聊天界面更换聊天背景)

    qq怎么换聊天位置左右(如何在qq聊天界面更换聊天背景)

  • 电动车充电器发热正常吗(电动车充电器发热充不进去电是什么原因)

    电动车充电器发热正常吗(电动车充电器发热充不进去电是什么原因)

  • 为什么拍的照片是横着的(为什么拍的照片比本人难看)

    为什么拍的照片是横着的(为什么拍的照片比本人难看)

  • accmeta_vod是什么文件夹(acdm是什么意思)

    accmeta_vod是什么文件夹(acdm是什么意思)

  • 因特网的拓扑结构是什么(因特网的拓扑结构是一种什么结构)

    因特网的拓扑结构是什么(因特网的拓扑结构是一种什么结构)

  • 电脑写文章用什么软件(电脑写文章用什么软件可以配图)

    电脑写文章用什么软件(电脑写文章用什么软件可以配图)

  • word页面字符怎么设置(word文档页面设置字符数和行数)

    word页面字符怎么设置(word文档页面设置字符数和行数)

  • 怎么改淘宝背景图片(怎么修改淘宝背景颜色)

    怎么改淘宝背景图片(怎么修改淘宝背景颜色)

  • qq面对面快传的视频在哪里(qq面对面快传的视频在相册找不到)

    qq面对面快传的视频在哪里(qq面对面快传的视频在相册找不到)

  • 物联卡实名了怎么解除(物联卡实名怎么实名)

    物联卡实名了怎么解除(物联卡实名怎么实名)

  • 笔记本电脑广告太多怎么办(SONY笔记本电脑广告)

    笔记本电脑广告太多怎么办(SONY笔记本电脑广告)

  • 小米cc9有红外遥控吗(小米cc9红外遥控在哪里打开)

    小米cc9有红外遥控吗(小米cc9红外遥控在哪里打开)

  • 小米8能用多久(小米8能用多久充电)

    小米8能用多久(小米8能用多久充电)

  • 华为钱包门禁卡用不了(华为钱包门禁卡模拟成功但是打不开)

    华为钱包门禁卡用不了(华为钱包门禁卡模拟成功但是打不开)

  • 抖音怎么保存静态壁纸(抖音怎么保存静态视频)

    抖音怎么保存静态壁纸(抖音怎么保存静态视频)

  • 华为畅享9plus隐藏功能怎么使用(华为畅享9plus隐私空间在哪)

    华为畅享9plus隐藏功能怎么使用(华为畅享9plus隐私空间在哪)

  • 摄像头离线怎么恢复(摄像头离线怎么重新连接wifi)

    摄像头离线怎么恢复(摄像头离线怎么重新连接wifi)

  • macos big sur状态栏怎么显示键盘亮度?(macos big sur卡在)

    macos big sur状态栏怎么显示键盘亮度?(macos big sur卡在)

  • 空调安装发票税率
  • 商业承兑过期后多久失效
  • 税盘锁了还能报税吗
  • 残疾人保障金做什么会计科目
  • 委托加工物资属于在产品吗
  • 小规模增值税做那个费用科目
  • 铁路运输发票的开具要求
  • 预缴企业所得税会计处理
  • 土地增值税清算规程实施细则
  • 补缴企业所得税滞纳金账务处理
  • 小微企业计算公式
  • 转让无形资产所有权计入什么科目
  • 所得税亏损财务处理办法
  • 增值税专用发票怎么开
  • 凯利公式实战
  • 事业单位发放的工作经费计入哪个科目
  • 汽车进项税额
  • 企业所得税发票虚假成本调减当年的吗
  • 怎么恢复系统win10
  • 购买所有物品都是可以退货吗
  • Windows10如何解压rar
  • 王者荣耀怎么解除关系
  • .linux文件
  • 广告费的会计科目
  • wedp是什么文件
  • 格洛利亚酒店
  • php中自定义常量的函数是
  • 公司亏损注销了怎么处理
  • thinkphp模糊查询
  • 公司现金发放证明
  • 个体工商户经济类型是内资吗
  • 农产品免税发票可以抵扣增值税吗
  • 新会计准则增加了哪些科目
  • 微信转账要如何退回去
  • python中numpy数组和列表的区别
  • linux db2安装与配置
  • mysql 小时差
  • 织梦cms为什么不维护了
  • sql2008附加数据库
  • 利用java实现计算器
  • 进项大于销项的会计分录怎么做?
  • 预付卡的增值税处理
  • 捐赠的增值税可以抵扣吗
  • 小规模未达起征点销售额是多少
  • 购进国内交通运输产品
  • 办公桌椅入什么会计科目
  • 应收账款收不回来
  • 季节性停工损失计入存货成本吗
  • 注册资本及构成
  • 投标保证金会计科目
  • 公司是生产企业现在要开出租赁的发票可以开吗
  • 分公司办事处需要什么手续
  • 收到微信公众号反诈骗风险提示
  • 购房增值税发票是购房发票吗
  • 增值税专用发票丢了怎么补救
  • 现金日记账月末怎么结账图片
  • 如何设置银行存款日记账
  • sql server连接不上服务器怎么办
  • mysql多表查询方式
  • win10 更新 蓝屏
  • rhel6安装
  • QQExternal.exe是什么进程?QQExternal.exe进程为什么被运行?
  • linux系统中常用的五种文件类型
  • xp系统怎么调性能
  • cocos2djs
  • perl怎么把字符串变为数字
  • python科学绘图
  • android/bitmap.h
  • beautiful python
  • js中的?
  • 基于unity的游戏开发
  • js实现框选
  • js简单网速测试方法
  • 税控盘状态
  • 车船税每年多少钱
  • 北京朝阳地税局电话号码
  • 国税局招录条件
  • 什么叫发票信息对比
  • 云旅游存在的问题及解决措施
  • 省国税局领导由谁任命
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设