位置: IT常识 - 正文

Pytorch深度学习实战3-6:详解网络骨架模块nn.Module(附实例)

编辑:rootadmin
原力计划Pytorch深度学习实战3-6:详解网络骨架模块nn.Module(附实例) 目录1 什么是nn.Module?2 从一个例子说起3 nn.Module主要方法4 自定义网络一般步骤1 什么是nn.Module?

推荐整理分享Pytorch深度学习实战3-6:详解网络骨架模块nn.Module(附实例),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

在实际应用过程中,经典网络结构(如卷积神经网络)往往不能满足我们的需求,因而大多数时候都需要自定义模型,比如:多输入多输出(MIMO)、多分支模型、跨层连接模型等。nn.Module就是Pytorch中用于自定义模型的核心方法。在Pytorch中,自定义层、自定义块、自定义模型,都是通过继承nn.Module类完成的。

nn.Module的定义如下

class Module(object): def __init__(self): def forward(self, *input): def __call__(self, *input, **kwargs): def parameters(self, recurse=True): def named_parameters(self, prefix='', recurse=True): def children(self): def named_children(self): def modules(self): def named_modules(self, memo=None, prefix=''): def train(self, mode=True): def eval(self): def zero_grad(self):...

注意:自定义网络需要继承nn.Module类,并重点实现上面的构造函数__init__构造函数和forward()这两个方法。

2 从一个例子说起

下面是一个自定义感知机的实例

# 感知机class Perception(nn.Module): def __init__(self, inDim, hidDim, outDim): super(Perception, self).__init__() self.perception = nn.Sequential( nn.Linear(inDim, hidDim), nn.Sigmoid(), nn.Linear(hidDim, outDim), nn.Sigmoid() ) def forward(self, x): return self.perception(x)

测试模块

perception = Perception(5,20,10)print(perception(torch.Tensor([1,2,3,4,5]))) # 自动调用forward()前向传播

其中nn.Sequential()可以序列化封装若干个相连的组件,在希望快速搭建模型且无需考虑中间过程的情形下,推荐使用nn.Sequential()进行局部模块化。

Pytorch深度学习实战3-6:详解网络骨架模块nn.Module(附实例)

从上面的实例可以看出:

一般把网络中的特定结构(如全连接层、卷积层等)以序列的形式放在构造函数__init__()中将模型自定义的各个层的连接关系和数据通路设计放在forward()函数中,以实现模型功能并保证数据结构正常不具有可学习参数的层(如ReLU、dropout、BatchNormanation层等)可并入__init__()内部的某个层,或在forward()函数中进行层间连接

库nn.functional同样提供了大量网络模块和组件,与nn.Module类不同在于其更偏向底层——nn.Module封装了对学习参数的维护,更注重模型结构;nn.functional需要手动指定参数和结构,例如下面线性模型Linear的核心源码,其前向过程仍然调用了底层的nn.functional实现。

class Linear(Module): def __init__(self, in_features: int, out_features: int) -> None: super(Linear, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(out_features, in_features)) self.bias = Parameter(torch.Tensor(out_features)) def forward(self, input: Tensor) -> Tensor: return F.linear(input, self.weight, self.bias)

一般在设计通过已有nn.Module无法组装的网络结构时,可以调用底层的nn.functional实现;或是存在无需优化学习参数的结构(如损失函数、激活函数等),可以调用nn.functional(即作为单纯函数使用)避免实例化nn.Module,轻量化网络

# 使用nn.Module需要实例化后调用lossFunc = nn.CrossEntropyLoss()loss = lossFunc(output, label)# 使用nn.functional则只作为函数即可loss = F.cross_entropy(output, label)3 nn.Module主要方法

nn.Module的主要属性与方法列举如表所示。

序号属性/方法含义1forward()模型前向传播2train()训练模式3eval()评估模式4named_parameters()返回模型各可学习参数的名称和参数组成的列表5parameters()返回模型各可学习参数组成的列表6children()返回一个迭代器,其中每个元素是Sequential序列类型,可以使用下标索引来进一步获取每一个Sequenrial里面的具体层,比如conv层、dense层等7named_children()返回一个迭代器,其中每个元素是一个二元组,第一元是名称,第二元是该名称对应的层或Sequential序列4 自定义网络一般步骤

自定义网络一般步骤总结如下:

自定义一个继承自Module的类实现构造函数_init__,在其中参数化网络层,比如卷积神经网络的卷积核大小、池化层尺寸,全连接网络的输入输出大小等;实现前向传播forward()接口,定义网络的连接情况或其他运算方式(如向量拼接、向量变维、数据处理等)

下面再给出一个卷积神经网络的实例加深理解

class CNN(nn.Module): def __init__(self): super().__init__() self.convPoolLayer_1 = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5), nn.MaxPool2d(kernel_size=2), nn.ReLU() ) self.convPoolLayer_2 = nn.Sequential( nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5), nn.MaxPool2d(kernel_size=2), nn.ReLU() ) self.fcLayer = nn.Linear(320, 10) def __str__(self) -> str: return "cnn_model" def forward(self, x): batchSize = x.size(0) x = self.convPoolLayer_1(x) x = self.convPoolLayer_2(x) x = x.reshape(batchSize, -1) x = self.fcLayer(x) return x


🔥 更多精彩专栏:

《ROS从入门到精通》《Pytorch深度学习实战》《机器学习强基计划》《运动规划实战精讲》…

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

本文链接地址:https://www.jiuchutong.com/zhishi/297615.html 转载请保留说明!

上一篇:【Vant Weapp】van-tab 标签页(vant weapp官方文档)

下一篇:Django--基于Python的Web应用框架(django pycharm)

  • 腾讯视频可以投屏到电视吗(腾讯视频可以投屏到微光吗)

    腾讯视频可以投屏到电视吗(腾讯视频可以投屏到微光吗)

  • 戴尔推出G3223Q和G3223D两款32吋游戏显示器哪个好一点(戴尔g3 2842)

    戴尔推出G3223Q和G3223D两款32吋游戏显示器哪个好一点(戴尔g3 2842)

  • 苹果手机打电话的时候没有网络(苹果手机打电话没声音)

    苹果手机打电话的时候没有网络(苹果手机打电话没声音)

  • 优酷同时登录几个设备

    优酷同时登录几个设备

  • 短信sp信息费是什么(c网短信sp内容费)

    短信sp信息费是什么(c网短信sp内容费)

  • 如何修改抖音上的作品(如何修改抖音上的IP)

    如何修改抖音上的作品(如何修改抖音上的IP)

  • 全民k歌的道具值几分(全民k歌的道具是怎么来的)

    全民k歌的道具值几分(全民k歌的道具是怎么来的)

  • 打电话关机打qq对方网络良好(打电话关机打微信视频显示对方忙线中怎么回事)

    打电话关机打qq对方网络良好(打电话关机打微信视频显示对方忙线中怎么回事)

  • 华为p40可以放大几倍(华为p40pro如何放大一百倍)

    华为p40可以放大几倍(华为p40pro如何放大一百倍)

  • 开关电源适配器是干什么用的(开关电源适配器能当充电器用吗)

    开关电源适配器是干什么用的(开关电源适配器能当充电器用吗)

  • 苹果手机序列号dn开头是什么版本(苹果手机序列号是哪个)

    苹果手机序列号dn开头是什么版本(苹果手机序列号是哪个)

  • 支持ipad pencil的ipad有哪些(支持pencil的平板)

    支持ipad pencil的ipad有哪些(支持pencil的平板)

  • 淘宝买东西用了优惠券退款了还会回来吗(淘宝买东西用了花呗但是支付宝没有花呗可以还款吗)

    淘宝买东西用了优惠券退款了还会回来吗(淘宝买东西用了花呗但是支付宝没有花呗可以还款吗)

  • 申请大王卡不去领可以吗(申请大王卡不去会扣钱吗)

    申请大王卡不去领可以吗(申请大王卡不去会扣钱吗)

  • 打包安装程序可以删吗(打包安装程序有用吗)

    打包安装程序可以删吗(打包安装程序有用吗)

  • 苹果11买回来有钢化膜吗(苹果11买回来有钢化膜吗?)

    苹果11买回来有钢化膜吗(苹果11买回来有钢化膜吗?)

  • ppt可以竖版吗(ppt可以用竖版吗)

    ppt可以竖版吗(ppt可以用竖版吗)

  • 电容麦克风和动圈麦克风的区别(电容麦克风和动圈麦克风哪个贵)

    电容麦克风和动圈麦克风的区别(电容麦克风和动圈麦克风哪个贵)

  • 红米k30有呼吸灯吗(红米k30有呼吸灯么)

    红米k30有呼吸灯吗(红米k30有呼吸灯么)

  • 怎么查找微信里的消费记录(怎么查找微信里的文件)

    怎么查找微信里的消费记录(怎么查找微信里的文件)

  • 苹果壁纸怎么设置缩小(苹果壁纸怎么设置景深效果)

    苹果壁纸怎么设置缩小(苹果壁纸怎么设置景深效果)

  • 服务器拒绝访问的原因(服务器拒绝访问,企业许可无效怎么回事)

    服务器拒绝访问的原因(服务器拒绝访问,企业许可无效怎么回事)

  • mmusbkb2.exe是什么进程 有什么作用 mmusbkb2进程查询(mmc.exe是什么进程)

    mmusbkb2.exe是什么进程 有什么作用 mmusbkb2进程查询(mmc.exe是什么进程)

  • 用PyCharm配置PyQt5:一键实现ui文件转py文件(在pycharm中配置python)

    用PyCharm配置PyQt5:一键实现ui文件转py文件(在pycharm中配置python)

  • 分公司可以享受企业所得税优惠吗
  • 以前年度多计收入多交税怎么处理
  • 工会经费计税
  • 运输业月末进销项税怎么结转
  • 失控发票补税可以抵扣吗
  • 公司收入是否应优先支付工资
  • 居民企业参股外国企业信息报告填写
  • 政府补贴是什么职能
  • 按价格从价缴纳增值税
  • 企业购买硬件与软件该如何做账?
  • 个体工商户有公户吗?
  • 装修公司能开增值发票吗
  • 住宅租给公司办什么手续
  • 做天猫合理避税吗
  • 少交了增值税怎么补
  • 业务人员出差住宿费记什么科目
  • 福利费专票不抵税可以吗
  • 建筑总包会计分录
  • 增资印花税增加哪个税目
  • 纳税人拒不缴滞纳金是否可单独强制执行
  • 收到赞助费开什么发票
  • 存货资产减值损失转回怎么做账
  • iphone6s怎么开启开发者选项
  • 用自产的产品用于生产线
  • 如何编制处置固定资产
  • 网速不稳定的解决方法
  • 小规模纳税人的企业所得税怎么算
  • 增值税发票要审核成功才能开吗
  • rtlcpl.exe
  • react img onerror
  • php语言之面向对象编程 educoder
  • wifi万能钥匙密码王
  • 蒙特城堡干红葡萄酒价格
  • 构造二叉排序树代码
  • 抄税的步骤
  • python中列表清空
  • 资产支出加权平均数例题
  • 什么是会计确认的基础
  • 非限定性资产和业务活动表关系
  • php cms
  • 将织梦dedecms转换到wordpress
  • 个人所得税其他扣除300一个月
  • 季度对账单 怎么处理
  • sqlserver2005简介
  • 子公司注销合并报表少数股东权益的处理
  • 预收账款是怎么算的
  • 上月计提多了怎么办
  • 开多少平方超市赚钱
  • 股东借款的利息收入
  • 增值税退税账务处理,经其他收益科目
  • 有限责任公司注册要求
  • 其他综合收益如何计算所得税
  • 合同负债包括
  • 服装厂委托物资怎么写
  • 收到工程款怎么入账
  • MySQL 5.6 中 TIMESTAMP 的变化分析
  • 数据库表的行数
  • 创建的sql语句
  • xp禁用win键
  • ubuntu for lot
  • macbook怎么关闭设置上的1
  • windows7开机启动
  • windows8快速启动设置
  • linux系统硬盘分区类型
  • libproj.so
  • js的上传文件
  • 合并多个js文件
  • nodejs搭建web服务器
  • 微信小程序实现轮播图
  • android判断横竖屏
  • mysql源码安装和二进制安装
  • 处理判断字符串是否相等
  • linux中的shell命令
  • android 多个权限合并 弹窗
  • js输出表格
  • 电子税务推广工作内容
  • 出口退税的汇率按什么时候的汇率
  • 加油河南app怎么注销
  • 一般纳税人企业所得税5%还是25%
  • 全国税收总收入完成59260.61
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设