位置: IT常识 - 正文

pytorch如何搭建一个最简单的模型,(pytorch如何搭建神经网络)

编辑:rootadmin
pytorch如何搭建一个最简单的模型, 一、搭建模型的步骤

推荐整理分享pytorch如何搭建一个最简单的模型,(pytorch如何搭建神经网络),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch搭建gan,pytorch环境搭建mac,pytorch如何搭建神经网络,pytorch怎么装,pytorch创建模型,pytorch 搭建简单网络,pytorch搭建gan,pytorch搭建gan,内容如对您有帮助,希望把文章链接给更多的朋友!

在 PyTorch 中,可以使用 torch.nn 模块来搭建深度学习模型。具体步骤如下:

定义一个继承自 torch.nn.Module 的类,这个类将作为我们自己定义的模型。

在类的构造函数 __init__() 中定义网络的各个层和参数。可以使用 torch.nn 模块中的各种层,如 Conv2d、BatchNorm2d、Linear 等。

在类中定义前向传播函数 forward(),实现模型的具体计算过程。

将模型部署到 GPU 上,可以使用 model.to(device) 将模型移动到指定的 GPU 设备上。

二、简单的例子pytorch如何搭建一个最简单的模型,(pytorch如何搭建神经网络)

下面是一个简单的例子,演示了如何使用 torch.nn 模块搭建一个简单的全连接神经网络:

import torch.nn as nnclass MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.fc1 = nn.Linear(784, 512) self.relu = nn.ReLU() self.fc2 = nn.Linear(512, 10) def forward(self, x): x = x.view(-1, 784) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x

MyNet 的神经网络类,它继承自 torch.nn.Module。在构造函数 __init__() 中定义了两个全连接层,一个 ReLU 激活函数,并将它们作为网络的成员变量。在前向传播函数 forward() 中,首先将输入的图像数据 x 压成一维向量,然后依次经过两个全连接层和一个 ReLU 激活函数,最终得到模型的输出结果。

在模型训练之前,需要将模型部署到 GPU 上,可以使用以下代码将模型移动到 GPU 上:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = MyNet().to(device)如何将loss函数添加到模型中去呢?

在 PyTorch 中,通常将损失函数作为单独的对象来定义,并在训练过程中手动计算和优化损失。为了将损失函数添加到模型中,需要在模型类中添加一个成员变量,然后在前向传播函数中计算损失。

下面是一个例子,演示了如何在模型中添加交叉熵损失函数:

import torch.nn as nnclass MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.fc1 = nn.Linear(784, 512) self.relu = nn.ReLU() self.fc2 = nn.Linear(512, 10) self.loss_fn = nn.CrossEntropyLoss() def forward(self, x, y): x = x.view(-1, 784) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) loss = self.loss_fn(x, y) return x, loss

在模型类 MyNet 的构造函数中添加了一个成员变量 self.loss_fn,它是交叉熵损失函数。在前向传播函数 forward() 中,传入两个参数 x 和 y,其中 x 是输入图像数据,y 是对应的标签。在函数中先执行正向传播计算,然后计算交叉熵损失,并将损失值作为输出返回。

实际训练代码

在实际训练过程中,首先将模型输出结果 x 和标签 y 传入前向传播函数 forward() 中计算损失,然后使用优化器更新模型的权重和偏置。代码如下:

model = MyNet()optimizer = torch.optim.SGD(model.parameters(), lr=0.01)for inputs, labels in data_loader: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs, loss = model(inputs, labels) loss.backward() optimizer.step()

在上面的代码中,使用随机梯度下降优化器 torch.optim.SGD 来更新模型的参数。在每个批次中,首先将输入数据和标签移动到 GPU 上,然后使用 optimizer.zero_grad() 将梯度清零。接着执行前向传播计算,并得到损失值 loss。最后使用 loss.backward() 计算梯度并执行反向传播,使用 optimizer.step() 更新模型参数。

2023.03.27更新 完整的代码

# -*-coding:utf-8-*-# !/usr/bin/env python# @Time : 2023/3/27 上午11:00# @Author : loveinfall uestc# @File : csdn_test_.py# @Description :import torchimport torch.nn as nnimport torch.utils.data as dataimport cv2####################### model ###########################class MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.fc1 = nn.Linear(784, 512) self.relu = nn.ReLU() self.fc2 = nn.Linear(512, 10) def forward(self, x): x = x.view(-1, 784) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x###################### end ############################################### loss 函数 #############################loss_fn = nn.CrossEntropyLoss()################## end #################################################### dataloader 需要自己构建 ############class image_folder(data.Dataset): def __init__(self): self.image_dirs = []#构造数据读取路径列表 self.label_dirs = [] def __getitem__(self,index): image = cv2.imread(self.image_dirs[index]) label = 'read data'#根据实际情况,写 return image,label def __len__(self): return 'len(data)'train_dataset = image_folder()data_loader = data.DataLoader( train_dataset, batch_size=3, shuffle=True, num_workers=2, pin_memory=True)#################### end ##################################################### train #######################@#####device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = MyNet().to(device)optimizer = torch.optim.SGD(model.parameters(), lr=0.01)for inputs, labels in data_loader: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = loss_fn(outputs,labels) loss.backward() optimizer.step()
本文链接地址:https://www.jiuchutong.com/zhishi/297834.html 转载请保留说明!

上一篇:【Vue】图片拉近、全屏背景实战经验总结(vue图片点击放大)

下一篇:什么是前后端分离(什么是前后端分离的方式)

  • macos catalina是什么意思

    macos catalina是什么意思

  • 电脑黑屏怎么办(电脑黑屏怎么办主机还是亮的)

    电脑黑屏怎么办(电脑黑屏怎么办主机还是亮的)

  • 微信腾讯投票是否匿名(微信腾讯投票怎么看结果)

    微信腾讯投票是否匿名(微信腾讯投票怎么看结果)

  • 苹果录屏横屏变成直屏怎么办(苹果屏幕录制横屏)

    苹果录屏横屏变成直屏怎么办(苹果屏幕录制横屏)

  • 国行三网通什么意思(国行和三网通的区别)

    国行三网通什么意思(国行和三网通的区别)

  • 荣耀10x支持wifi6吗(荣耀10x支持多少瓦快充)

    荣耀10x支持wifi6吗(荣耀10x支持多少瓦快充)

  • 苹果电脑耳机孔在哪(苹果电脑耳机孔图片)

    苹果电脑耳机孔在哪(苹果电脑耳机孔图片)

  • qq分享屏幕能看见对方吗(QQ分享屏幕能看到脸吗)

    qq分享屏幕能看见对方吗(QQ分享屏幕能看到脸吗)

  • 电脑显示屏自带摄像头吗(电脑显示屏自带声音吗)

    电脑显示屏自带摄像头吗(电脑显示屏自带声音吗)

  • 手机报警怎么回事(手机报警有用吗)

    手机报警怎么回事(手机报警有用吗)

  • 怎么删除qq里的自动回复(怎么删除qq里的小世界功能)

    怎么删除qq里的自动回复(怎么删除qq里的小世界功能)

  • mac连wifi密码无效wp2(mac wifi连接上不能上网怎么办)

    mac连wifi密码无效wp2(mac wifi连接上不能上网怎么办)

  • nova6 5g是什么系统(nova 6是不是5g手机)

    nova6 5g是什么系统(nova 6是不是5g手机)

  • 微博身份证使用次数超限(微博身份证使用次数超限怎么解决)

    微博身份证使用次数超限(微博身份证使用次数超限怎么解决)

  • 手机无声音是什么原因,插耳机就有(手机无声音咋办)

    手机无声音是什么原因,插耳机就有(手机无声音咋办)

  • ipad属于平板电脑吗(ipad属于电器吗)

    ipad属于平板电脑吗(ipad属于电器吗)

  • 苹果8死机了怎么重启(iphone 8死机)

    苹果8死机了怎么重启(iphone 8死机)

  • 苹果11要不要升级13.1.2

    苹果11要不要升级13.1.2

  • cad字高怎么设置(cad中字高怎么设置)

    cad字高怎么设置(cad中字高怎么设置)

  • 小米体重秤怎么连接米家(小米体重秤怎么调公斤和斤)

    小米体重秤怎么连接米家(小米体重秤怎么调公斤和斤)

  • 群玩助手对方能察觉吗(群玩助手一直能定位吗)

    群玩助手对方能察觉吗(群玩助手一直能定位吗)

  • 将U盘制作成安装LION、WIN7系统盘,方便MacBook Air没有光驱下安装双系统(怎样将u盘制作成电脑系统启动盘?)

    将U盘制作成安装LION、WIN7系统盘,方便MacBook Air没有光驱下安装双系统(怎样将u盘制作成电脑系统启动盘?)

  • Windows 10带来了新的触摸键盘体验,而且非常棒(win 10有什么用)

    Windows 10带来了新的触摸键盘体验,而且非常棒(win 10有什么用)

  • 电脑每次开机都要按f1解决方法(电脑每次开机都要按f1怎么解决)

    电脑每次开机都要按f1解决方法(电脑每次开机都要按f1怎么解决)

  • 同一商品税收分类编码不一样
  • 小规模纳税人月收入多少免征增值税
  • 金税盘清卡怎么统计税额
  • 食品增值税专用发票可以退税吗?
  • 事业单位固定资产入账标准最新规定
  • 一般纳税人临时工工资怎么入账
  • 其他应收款有什么业务
  • 红字发票信息表是销方还是购方开
  • 无法收回的应收款计入什么科目
  • 外经证预缴附加税
  • 开给个人的通讯费发票能下账吗
  • 景区门票入什么费用
  • 增值税主要有三种类型
  • 机票退票手续费为什么这么贵
  • 企业所得税是否有利于调节产业结构
  • 本月收到外汇怎么做账
  • 票面税费和实际上税为什么不一样
  • 在建工程暂估入库的账务处理
  • 固定资产原值减预计净残值等于什么
  • 母公司计提子公司投资收益
  • 工商注销债务承担
  • 教育预收费
  • 长时间不操作电动座椅会发生什么
  • 无形资产入账摊销
  • 发票认证如何认证
  • php文件包含的4种方式
  • 电脑中毒如何处理
  • 外籍人员个人所得税政策2023规定
  • 公众号 隐藏文章
  • 劳务费可以抵扣进项吗
  • 代扣增值税如何做账
  • 汇兑损益计入什么科目
  • 从零开始文章
  • 后处理作用
  • php+flash+jQuery多图片上传源码分享
  • mysql全局锁和表锁
  • 扶贫资金入股问题
  • 结转费用类会计分录怎么写
  • 建筑公司遇到的问题
  • 在sqlserver2008中
  • 帝国cms如何做网站
  • 所有者权益的确认依附于什么的确认
  • 停工期间工资支付标准
  • 公司什么项目
  • 企业所得税的扣除是什么意思
  • 不能防止sql注入
  • 通行费抵扣进项税怎么做账
  • 收到汇算清缴退回的税款如何做账
  • 材料合理损耗计入入账价值吗
  • 物业公司劳务外包
  • 电费可以计入营业外收入吗
  • 资源税代扣代缴取消时间
  • 日常生活中各种形式的能量的转化
  • 购买无形资产的手续费计入
  • 小规模纳税人注册资金要求多少
  • 居间活动费用由谁负担
  • 销售红酒的公司
  • 房地产企业会计科目
  • sqlserver 中ntext字段的批量替换(updatetext的用法)
  • mysql免安装版怎么使用
  • 服务器时间和电脑时间
  • 手动GHOST安装系统方法教程图解
  • 搜狗浏览器ie8
  • mac显示隐藏文件夹
  • centos7(core)
  • vsftpd配置用户登录目录
  • linux 追踪
  • 在windowsxp的应用程序中,经常有一些菜单选项呈暗灰色
  • win8计算器在哪里找
  • win7如何限制网速
  • win7怎么查是不是正版
  • iframe移动端自适应
  • linux压缩tar文件命令
  • 深入解读2023年一号文件
  • 手游 unity
  • python 二叉堆
  • javascript学习指南
  • 重庆市网上税务局官网
  • 三种人不交个人所得税?
  • 加强监督管理工作
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设