位置: IT常识 - 正文

pytorch如何搭建一个最简单的模型,(pytorch如何搭建神经网络)

编辑:rootadmin
pytorch如何搭建一个最简单的模型, 一、搭建模型的步骤

推荐整理分享pytorch如何搭建一个最简单的模型,(pytorch如何搭建神经网络),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch搭建gan,pytorch环境搭建mac,pytorch如何搭建神经网络,pytorch怎么装,pytorch创建模型,pytorch 搭建简单网络,pytorch搭建gan,pytorch搭建gan,内容如对您有帮助,希望把文章链接给更多的朋友!

在 PyTorch 中,可以使用 torch.nn 模块来搭建深度学习模型。具体步骤如下:

定义一个继承自 torch.nn.Module 的类,这个类将作为我们自己定义的模型。

在类的构造函数 __init__() 中定义网络的各个层和参数。可以使用 torch.nn 模块中的各种层,如 Conv2d、BatchNorm2d、Linear 等。

在类中定义前向传播函数 forward(),实现模型的具体计算过程。

将模型部署到 GPU 上,可以使用 model.to(device) 将模型移动到指定的 GPU 设备上。

二、简单的例子pytorch如何搭建一个最简单的模型,(pytorch如何搭建神经网络)

下面是一个简单的例子,演示了如何使用 torch.nn 模块搭建一个简单的全连接神经网络:

import torch.nn as nnclass MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.fc1 = nn.Linear(784, 512) self.relu = nn.ReLU() self.fc2 = nn.Linear(512, 10) def forward(self, x): x = x.view(-1, 784) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x

MyNet 的神经网络类,它继承自 torch.nn.Module。在构造函数 __init__() 中定义了两个全连接层,一个 ReLU 激活函数,并将它们作为网络的成员变量。在前向传播函数 forward() 中,首先将输入的图像数据 x 压成一维向量,然后依次经过两个全连接层和一个 ReLU 激活函数,最终得到模型的输出结果。

在模型训练之前,需要将模型部署到 GPU 上,可以使用以下代码将模型移动到 GPU 上:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = MyNet().to(device)如何将loss函数添加到模型中去呢?

在 PyTorch 中,通常将损失函数作为单独的对象来定义,并在训练过程中手动计算和优化损失。为了将损失函数添加到模型中,需要在模型类中添加一个成员变量,然后在前向传播函数中计算损失。

下面是一个例子,演示了如何在模型中添加交叉熵损失函数:

import torch.nn as nnclass MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.fc1 = nn.Linear(784, 512) self.relu = nn.ReLU() self.fc2 = nn.Linear(512, 10) self.loss_fn = nn.CrossEntropyLoss() def forward(self, x, y): x = x.view(-1, 784) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) loss = self.loss_fn(x, y) return x, loss

在模型类 MyNet 的构造函数中添加了一个成员变量 self.loss_fn,它是交叉熵损失函数。在前向传播函数 forward() 中,传入两个参数 x 和 y,其中 x 是输入图像数据,y 是对应的标签。在函数中先执行正向传播计算,然后计算交叉熵损失,并将损失值作为输出返回。

实际训练代码

在实际训练过程中,首先将模型输出结果 x 和标签 y 传入前向传播函数 forward() 中计算损失,然后使用优化器更新模型的权重和偏置。代码如下:

model = MyNet()optimizer = torch.optim.SGD(model.parameters(), lr=0.01)for inputs, labels in data_loader: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs, loss = model(inputs, labels) loss.backward() optimizer.step()

在上面的代码中,使用随机梯度下降优化器 torch.optim.SGD 来更新模型的参数。在每个批次中,首先将输入数据和标签移动到 GPU 上,然后使用 optimizer.zero_grad() 将梯度清零。接着执行前向传播计算,并得到损失值 loss。最后使用 loss.backward() 计算梯度并执行反向传播,使用 optimizer.step() 更新模型参数。

2023.03.27更新 完整的代码

# -*-coding:utf-8-*-# !/usr/bin/env python# @Time : 2023/3/27 上午11:00# @Author : loveinfall uestc# @File : csdn_test_.py# @Description :import torchimport torch.nn as nnimport torch.utils.data as dataimport cv2####################### model ###########################class MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.fc1 = nn.Linear(784, 512) self.relu = nn.ReLU() self.fc2 = nn.Linear(512, 10) def forward(self, x): x = x.view(-1, 784) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x###################### end ############################################### loss 函数 #############################loss_fn = nn.CrossEntropyLoss()################## end #################################################### dataloader 需要自己构建 ############class image_folder(data.Dataset): def __init__(self): self.image_dirs = []#构造数据读取路径列表 self.label_dirs = [] def __getitem__(self,index): image = cv2.imread(self.image_dirs[index]) label = 'read data'#根据实际情况,写 return image,label def __len__(self): return 'len(data)'train_dataset = image_folder()data_loader = data.DataLoader( train_dataset, batch_size=3, shuffle=True, num_workers=2, pin_memory=True)#################### end ##################################################### train #######################@#####device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = MyNet().to(device)optimizer = torch.optim.SGD(model.parameters(), lr=0.01)for inputs, labels in data_loader: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = loss_fn(outputs,labels) loss.backward() optimizer.step()
本文链接地址:https://www.jiuchutong.com/zhishi/297834.html 转载请保留说明!

上一篇:【Vue】图片拉近、全屏背景实战经验总结(vue图片点击放大)

下一篇:什么是前后端分离(什么是前后端分离的方式)

  • 花呗可以分24期吗(花呗可以分24期买苹果电脑吗)

    花呗可以分24期吗(花呗可以分24期买苹果电脑吗)

  • word页眉横线去掉方法是什么(word文档页眉横线去掉)

    word页眉横线去掉方法是什么(word文档页眉横线去掉)

  • 手机后盖原装和非原装的区别(手机后盖原装和原装区别)

    手机后盖原装和非原装的区别(手机后盖原装和原装区别)

  • 手机屏幕边缘黑色阴影(手机屏幕边缘黑线)

    手机屏幕边缘黑色阴影(手机屏幕边缘黑线)

  • 音响声道2.1和5.1区别(音响声道2.1和5.1和7.1)

    音响声道2.1和5.1区别(音响声道2.1和5.1和7.1)

  • 华为应用助手不见了(华为应用助手不小心卸载了)

    华为应用助手不见了(华为应用助手不小心卸载了)

  • 滴滴异地可以接单吗(滴滴异地接单的规定)

    滴滴异地可以接单吗(滴滴异地接单的规定)

  • 更新系统时遇到卡米怎么办答案(系统更新遇到错误什么原因)

    更新系统时遇到卡米怎么办答案(系统更新遇到错误什么原因)

  • 抖音老照片修复在哪里(抖音老照片修复怎么弄)

    抖音老照片修复在哪里(抖音老照片修复怎么弄)

  • 淘宝退款对买家有影响吗(淘宝退款买家未举证怎么处理)

    淘宝退款对买家有影响吗(淘宝退款买家未举证怎么处理)

  • 怎么在电脑上写电子版的文章(番茄小说怎么在电脑上写)

    怎么在电脑上写电子版的文章(番茄小说怎么在电脑上写)

  • 买的新手机老卡怎么回事(新买手机老是卡怎么办)

    买的新手机老卡怎么回事(新买手机老是卡怎么办)

  • 腾讯会议如何录视频(腾讯会议如何录播课程)

    腾讯会议如何录视频(腾讯会议如何录播课程)

  • qq友谊的巨轮几天能掉(qq的友谊的巨轮有个数限制吗)

    qq友谊的巨轮几天能掉(qq的友谊的巨轮有个数限制吗)

  • 华为p20内存不够用怎么办(华为p20pro内存不够)

    华为p20内存不够用怎么办(华为p20pro内存不够)

  • 1t机械硬盘是什么意思(1T机械硬盘是什么)

    1t机械硬盘是什么意思(1T机械硬盘是什么)

  • itunes store是什么

    itunes store是什么

  • 电脑键盘灯亮没反应是怎么回事(电脑键盘灯亮没声音)

    电脑键盘灯亮没反应是怎么回事(电脑键盘灯亮没声音)

  • 电源和负载的本质区别(电源和负载的参考方向)

    电源和负载的本质区别(电源和负载的参考方向)

  • 多媒体系统只能在微机上运行吗(多媒体系统的常用设备有什么)

    多媒体系统只能在微机上运行吗(多媒体系统的常用设备有什么)

  • 怎么退出多页面视图(怎么退出多页面视图模式)

    怎么退出多页面视图(怎么退出多页面视图模式)

  • 气喘吁吁的意思(气喘吁吁的意思三年级)

    气喘吁吁的意思(气喘吁吁的意思三年级)

  • 如何下载虾米音乐mp3(虾米音乐怎样免费下载)

    如何下载虾米音乐mp3(虾米音乐怎样免费下载)

  • 微信收不了红包怎么回事(不添加银行卡微信收不了红包)

    微信收不了红包怎么回事(不添加银行卡微信收不了红包)

  • mac钥匙串密码忘记了(macos钥匙串密码忘记)

    mac钥匙串密码忘记了(macos钥匙串密码忘记)

  • 小米8慢动作只能拍10秒(小米8慢动作怎么不见了)

    小米8慢动作只能拍10秒(小米8慢动作怎么不见了)

  • 所得税税前扣除项目及扣除标准
  • 增值税纳税申报时间
  • 一般纳税人一直零申报会降为小规模吗
  • 人工费用的核算例题
  • 简易征收开出去的票可以抵扣吗
  • 资产负债表中未交税金负数表示什么
  • 未预缴开票
  • 发票丢失说明怎么填写
  • 计提坏账准备的做法体现了什么的信息质量要求
  • 税控系统维护费账务处理
  • 业务有提成个税怎么扣
  • 小规模纳税人没有成本票怎么做账
  • 地税的发票
  • 民办学校都没有编制吗
  • 工会经费什么时候返还给企业
  • 小规模纳税人如何计算增值税
  • 两处拿工资的缴税问题
  • 房地产转让的条件
  • 建筑业主营业务成本包括哪些
  • 用产品抵债的合同怎么写
  • React developer tools调试工具全网最新最全安装教程
  • 坏账损失税务处理
  • vue绑定css样式
  • 购买股票的会计科目
  • axios.defaults.baseURL的三种配置方法
  • Win11 Build 23430 预览版发布(附更新修复内容汇总)
  • HTML 事件参考手册
  • 微信随机红包表情包怎么弄
  • 新申报是什么
  • 文化事业建设费的征收范围
  • mysql左连接查询 效率
  • 公司车辆出售要交多少税
  • 个人所得税专项附加扣除赡养老人
  • 金税四期的特点
  • 收回投资收到的现金减少
  • 付国外专利费用需办什么手续
  • 外贸出口增值税附表二填哪项
  • 无形资产界定
  • 境外运费支付属什么费用
  • 主营业务成本如何设置明细
  • 公司给员工发福利图片
  • 商贸企业发出商品怎么确认收入
  • 国地税合并对个人带来的影响
  • 宾馆收入怎么做账
  • 营业外收入如何纳税
  • 解除劳动合同的合法程序
  • 人力公司开的代驾发票
  • 汽车三产件
  • 股权转让的会计分录
  • 双倍余额递减法最后两年怎么算
  • 会计总账怎么做账
  • 报关单位分为几种类型?其业务范围有何不同?
  • 快速插入大量数据的asp.net代码(Sqlserver)
  • mysql的join有几种
  • mysql在cmd命令操作
  • win8的应用商店在哪
  • win8安装虚拟机的步骤
  • win8系统的运行在哪里打开
  • xp剪贴板怎么打开
  • vista和win7哪个对配置要求高
  • ubuntu系统鼠标没反应
  • linux所谓的free
  • fedora phpMyAdmin 安装方法及介绍
  • linux route -n命令结果详解
  • 拒绝远程操作
  • linux系统fedora
  • 安卓手机屏幕不好使了怎么办
  • 一个项目引多个项目
  • 使用的拼音
  • 国际安卓应用市场
  • 用批处理删除注册表项
  • input输入@弹出框
  • unity3d课程
  • 扩展坞哪个牌子比较好
  • eclipse怎么查看项目的位置
  • 深入理解android卷1 pdf
  • jquery22插件网
  • before和after在句子中怎么翻译
  • 成都税务怎么查询社保缴费记录
  • 企业补缴公积金 归集额增加
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设