位置: IT常识 - 正文

pytorch如何搭建一个最简单的模型,(pytorch如何搭建神经网络)

编辑:rootadmin
pytorch如何搭建一个最简单的模型, 一、搭建模型的步骤

推荐整理分享pytorch如何搭建一个最简单的模型,(pytorch如何搭建神经网络),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch搭建gan,pytorch环境搭建mac,pytorch如何搭建神经网络,pytorch怎么装,pytorch创建模型,pytorch 搭建简单网络,pytorch搭建gan,pytorch搭建gan,内容如对您有帮助,希望把文章链接给更多的朋友!

在 PyTorch 中,可以使用 torch.nn 模块来搭建深度学习模型。具体步骤如下:

定义一个继承自 torch.nn.Module 的类,这个类将作为我们自己定义的模型。

在类的构造函数 __init__() 中定义网络的各个层和参数。可以使用 torch.nn 模块中的各种层,如 Conv2d、BatchNorm2d、Linear 等。

在类中定义前向传播函数 forward(),实现模型的具体计算过程。

将模型部署到 GPU 上,可以使用 model.to(device) 将模型移动到指定的 GPU 设备上。

二、简单的例子pytorch如何搭建一个最简单的模型,(pytorch如何搭建神经网络)

下面是一个简单的例子,演示了如何使用 torch.nn 模块搭建一个简单的全连接神经网络:

import torch.nn as nnclass MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.fc1 = nn.Linear(784, 512) self.relu = nn.ReLU() self.fc2 = nn.Linear(512, 10) def forward(self, x): x = x.view(-1, 784) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x

MyNet 的神经网络类,它继承自 torch.nn.Module。在构造函数 __init__() 中定义了两个全连接层,一个 ReLU 激活函数,并将它们作为网络的成员变量。在前向传播函数 forward() 中,首先将输入的图像数据 x 压成一维向量,然后依次经过两个全连接层和一个 ReLU 激活函数,最终得到模型的输出结果。

在模型训练之前,需要将模型部署到 GPU 上,可以使用以下代码将模型移动到 GPU 上:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = MyNet().to(device)如何将loss函数添加到模型中去呢?

在 PyTorch 中,通常将损失函数作为单独的对象来定义,并在训练过程中手动计算和优化损失。为了将损失函数添加到模型中,需要在模型类中添加一个成员变量,然后在前向传播函数中计算损失。

下面是一个例子,演示了如何在模型中添加交叉熵损失函数:

import torch.nn as nnclass MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.fc1 = nn.Linear(784, 512) self.relu = nn.ReLU() self.fc2 = nn.Linear(512, 10) self.loss_fn = nn.CrossEntropyLoss() def forward(self, x, y): x = x.view(-1, 784) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) loss = self.loss_fn(x, y) return x, loss

在模型类 MyNet 的构造函数中添加了一个成员变量 self.loss_fn,它是交叉熵损失函数。在前向传播函数 forward() 中,传入两个参数 x 和 y,其中 x 是输入图像数据,y 是对应的标签。在函数中先执行正向传播计算,然后计算交叉熵损失,并将损失值作为输出返回。

实际训练代码

在实际训练过程中,首先将模型输出结果 x 和标签 y 传入前向传播函数 forward() 中计算损失,然后使用优化器更新模型的权重和偏置。代码如下:

model = MyNet()optimizer = torch.optim.SGD(model.parameters(), lr=0.01)for inputs, labels in data_loader: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs, loss = model(inputs, labels) loss.backward() optimizer.step()

在上面的代码中,使用随机梯度下降优化器 torch.optim.SGD 来更新模型的参数。在每个批次中,首先将输入数据和标签移动到 GPU 上,然后使用 optimizer.zero_grad() 将梯度清零。接着执行前向传播计算,并得到损失值 loss。最后使用 loss.backward() 计算梯度并执行反向传播,使用 optimizer.step() 更新模型参数。

2023.03.27更新 完整的代码

# -*-coding:utf-8-*-# !/usr/bin/env python# @Time : 2023/3/27 上午11:00# @Author : loveinfall uestc# @File : csdn_test_.py# @Description :import torchimport torch.nn as nnimport torch.utils.data as dataimport cv2####################### model ###########################class MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.fc1 = nn.Linear(784, 512) self.relu = nn.ReLU() self.fc2 = nn.Linear(512, 10) def forward(self, x): x = x.view(-1, 784) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x###################### end ############################################### loss 函数 #############################loss_fn = nn.CrossEntropyLoss()################## end #################################################### dataloader 需要自己构建 ############class image_folder(data.Dataset): def __init__(self): self.image_dirs = []#构造数据读取路径列表 self.label_dirs = [] def __getitem__(self,index): image = cv2.imread(self.image_dirs[index]) label = 'read data'#根据实际情况,写 return image,label def __len__(self): return 'len(data)'train_dataset = image_folder()data_loader = data.DataLoader( train_dataset, batch_size=3, shuffle=True, num_workers=2, pin_memory=True)#################### end ##################################################### train #######################@#####device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = MyNet().to(device)optimizer = torch.optim.SGD(model.parameters(), lr=0.01)for inputs, labels in data_loader: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = loss_fn(outputs,labels) loss.backward() optimizer.step()
本文链接地址:https://www.jiuchutong.com/zhishi/297834.html 转载请保留说明!

上一篇:【Vue】图片拉近、全屏背景实战经验总结(vue图片点击放大)

下一篇:什么是前后端分离(什么是前后端分离的方式)

  • 开网店营销推广有什么技巧?(开网店怎么推广运营-经验问答)

    开网店营销推广有什么技巧?(开网店怎么推广运营-经验问答)

  • 苹果照片小组件怎么设置喜欢的照片(苹果照片小组件为什么显示无可用内容)

    苹果照片小组件怎么设置喜欢的照片(苹果照片小组件为什么显示无可用内容)

  • 抖音转发给朋友的顺序怎么删除(抖音转发给朋友的顺序是怎么来的)

    抖音转发给朋友的顺序怎么删除(抖音转发给朋友的顺序是怎么来的)

  • 抖音视频通话有美颜吗

    抖音视频通话有美颜吗

  • 投屏不是全屏怎么办(投屏不是全屏怎么设置)

    投屏不是全屏怎么办(投屏不是全屏怎么设置)

  • 远程上班是什么意思(远程工作有什么优点)

    远程上班是什么意思(远程工作有什么优点)

  • 荣耀体脂秤为何只显示体重(荣耀体脂秤不亮怎么回事)

    荣耀体脂秤为何只显示体重(荣耀体脂秤不亮怎么回事)

  • 屏幕花了是内屏坏了吗(屏幕花了是内屏还是外屏坏了)

    屏幕花了是内屏坏了吗(屏幕花了是内屏还是外屏坏了)

  • 小米6屏幕刷新率(小米6屏幕刷新率在哪里设置)

    小米6屏幕刷新率(小米6屏幕刷新率在哪里设置)

  • 电信4g首选网络类型(电信4g首选网络是什么)

    电信4g首选网络类型(电信4g首选网络是什么)

  • 大数据处理流程的第一步是(大数据处理流程的步骤)

    大数据处理流程的第一步是(大数据处理流程的步骤)

  • qq号很久没用会不会被注销(qq很久没用会变其他头像吗)

    qq号很久没用会不会被注销(qq很久没用会变其他头像吗)

  • ml4lte是什么版本手机(ml4lte现在价多少钱)

    ml4lte是什么版本手机(ml4lte现在价多少钱)

  • ipad耗电快是什么原因(ipad耗电非常快)

    ipad耗电快是什么原因(ipad耗电非常快)

  • aloo是华为什么型号(lio aloo是华为什么型号)

    aloo是华为什么型号(lio aloo是华为什么型号)

  • huawei share关不掉(huawei share没反应)

    huawei share关不掉(huawei share没反应)

  • 快手实验室没有上下滑动怎么设置(快手实验室没有k歌功能)

    快手实验室没有上下滑动怎么设置(快手实验室没有k歌功能)

  • oppoa83禁止安装怎么解除(oppo禁止安装程序,要怎么设置才能解除?)

    oppoa83禁止安装怎么解除(oppo禁止安装程序,要怎么设置才能解除?)

  • keyword.exe是什么进程 有什么用 keyword进程查询(key是什么文档)

    keyword.exe是什么进程 有什么用 keyword进程查询(key是什么文档)

  • Win11 Dev 预览版 22483更新发布推送(附完整更新内容)(win10dev预览版)

    Win11 Dev 预览版 22483更新发布推送(附完整更新内容)(win10dev预览版)

  • HTTP 协议

    HTTP 协议

  • 实现瀑布流布局的四种方法(瀑布流实现方式)

    实现瀑布流布局的四种方法(瀑布流实现方式)

  • phpcms不允许上传该类型文件怎么办(php.ini上传限制)

    phpcms不允许上传该类型文件怎么办(php.ini上传限制)

  • 产权转移书据印花税政策
  • 金税三期的主要系统
  • 新个税税率法
  • 已经认证抵扣的发票还能作废吗
  • 个体户转一般纳税人怎么做账
  • 发票用完了领发票需要带什么东西
  • 万元版和十万元版可以一起用吗
  • 购买不良资产交印花税吗
  • 发票可用时间
  • 库存现金贷方为负数说明什么
  • 一般纳税人销项开普票,进项票可以抵扣吗
  • 商铺售后回租会计处理
  • 调整跨期收入是否调增值税
  • 收入跨期审计调整分录如何滚调
  • 收到银行存款怎么记账
  • 购买投资理财产品放的会计处理怎么做?
  • 分公司亏损还会分摊所得税吗
  • 小规模增值税缴纳怎么算
  • 劳务公司差额征收税率是多少
  • 办公室装修费用计入什么会计科目
  • 财务会计中关于坏账损失的账务处理
  • 公允价值变动损益在利润表哪里
  • 进项税额及存货减值
  • win11打开软件出现????????
  • 长期借款利息费用的资本化账务处理
  • php如何运行脚本
  • 什么是增值税差额征税政策的小规模纳税人
  • 场外期权会计核算
  • linux命令大全详解
  • 退休后工作单位填什么内容
  • php echo语句
  • 医院产生的相关法律法规
  • php模板教程
  • php计算时间
  • Centos6.5和Centos7 php环境搭建方法
  • html中写php
  • 网站为什么需要备案
  • 上年度固定资产少入账了怎么办?
  • 一般纳税人销售自己使用过的汽车
  • 增值税专用发票和普通发票的区别
  • 记账凭证的主要作用有
  • 主营营业成本会计分录
  • 缴纳上年汇算清缴的分录
  • 小规模纳税人可以做进出口贸易吗
  • 小微企业报税是多久报一次
  • 口罩属于什么经济分类
  • 固定资产加速折旧计算方法
  • 贷款担保费应计入什么
  • 行政事业单位应用方案总账,财务分析
  • 库存现金账实不符怎么处理
  • 高速公路过路费怎么算的
  • 逾期未认证的增值税发票处理办法
  • 扣除工程款说明
  • 建立固定资产管理台账
  • 支付给烟农的价格怎么算
  • cmd提示符基础知识
  • 老毛桃u盘启动制作工具如何把原来的win7改xp系统图文教程
  • 台式电脑NUM LOCK键还能亮,算不算死机了
  • freebsd挂载ntfs
  • .exe是什么软件
  • Linux服务器管理的开机界面
  • linux批处理文件怎么写
  • windows7可以打开多个窗口
  • 32位系统的电脑可以连接打印机吗
  • Win10 Mobile RedStone预览版14267已知问题与修复内容汇总
  • 基于jQuery中ajax的相关方法汇总(必看篇)
  • cocos3.0
  • android 自定义
  • 批处理判断一个文件是否存在
  • unity导出3d模型
  • 深入理解计算机系统
  • 完美解决雷电模拟器卡顿
  • jquery on()
  • python selectfrommodel
  • 辽宁省国家税务局官网
  • 国税补录信息怎么查询
  • 上海交电费户号8位数
  • 工资薪金的税收金额是填实际发生还是帐载金额
  • 湖南省税务举报
  • 注册一个信息咨询公司需要什么
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设