位置: IT常识 - 正文

pytorch从零开始搭建神经网络(pytorch新手入门)

编辑:rootadmin
pytorch从零开始搭建神经网络 目录

基本流程

一、数据处理

二、模型搭建

三、定义代价函数&优化器

四、训练

附录

nn.Sequential

nn.Module

model.train() 和 model.eval() 

损失

图神经网络

基本流程《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili

推荐整理分享pytorch从零开始搭建神经网络(pytorch新手入门),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch入门教程(非常详细),pytorch 入门教程,pytorch入门到进阶,pytorch新手入门,pytorch 快速入门,pytorch零基础入门,pytorch新手入门,pytorch零基础入门,内容如对您有帮助,希望把文章链接给更多的朋友!

1. 数据预处理(Dataset、Dataloader)

2. 模型搭建(nn.Module)

3. 损失&优化(loss、optimizer)

4. 训练(forward、backward)

一、数据处理

对于数据处理,最为简单的⽅式就是将数据组织成为⼀个 。

但许多训练需要⽤到mini-batch,直 接组织成Tensor不便于我们操作。

pytorch为我们提供了Dataset和Dataloader两个类来方便的构建。

torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)

 

二、模型搭建

搭建一个简易的神经网络

除了采用pytorch自动梯度的方法来搭建神经网络,还可以通过构建一个继承了torch.nn.Module的新类,来完成forward和backward的重写。

# 神经网络搭建import torchfrom torch.autograd import Varible batch_n = 100 hidden_layer = 100 input_data = 1000output_data = 10 class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()def forward(self,input,w1,w2):x = torch.mm(input,w1)x = torch.clamp(x,min = 0)x = torch.mm(x,w2) def backward(self): passmodel = Model()#训练x = Variable(torch.randn(batch_n,input_data))

一点一点地看:

import torchdtype = torch.floatdevice = torch.device("cpu")N, D_in, H, D_out = 64, 1000, 100, 10# Create random input and output datax = torch.randn(N, D_in, device=device, dtype=dtype)y = torch.randn(N, D_out, device=device, dtype=dtype)# Randomly initialize weightsw1 = torch.randn(D_in, H, device=device, dtype=dtype)w2 = torch.randn(H, D_out, device=device, dtype=dtype)learning_rate = 1e-6

tensor 写一个粗糙版本(后面陆陆续续用pytorch提供的方法)

for t in range(500): # Forward pass: compute predicted y h = x.mm(w1) h_relu = h.clamp(min=0) y_pred = h_relu.mm(w2) # Compute and print loss loss = (y_pred - y).pow(2).sum().item() if t % 100 == 99: print(t, loss) # Backprop to compute gradients of w1 and w2 with respect to loss grad_y_pred = 2.0 * (y_pred - y) grad_w2 = h_relu.t().mm(grad_y_pred) grad_h_relu = grad_y_pred.mm(w2.t()) grad_h = grad_h_relu.clone() grad_h[h < 0] = 0 grad_w1 = x.t().mm(grad_h) # Update weights using gradient descent w1 -= learning_rate * grad_w1 w2 -= learning_rate * grad_w2三、定义代价函数&优化器

Autograd

for t in range(500): y_pred = x.mm(w1).clamp(min=0).mm(w2) loss = (y_pred - y).pow(2).sum() if t % 100 == 99: print(t, loss.item()) loss.backward() with torch.no_grad(): w1 -= learning_rate * w1.grad w2 -= learning_rate * w2.grad w1.grad.zero_() w2.grad.zero_()

对于需要计算导数的变量(w1和w2)创建时设定requires_grad=True,之后对于由它们参与计算的变量(例如loss),可以使用loss.backward()函数求出loss对所有requires_grad=True的变量的梯度,保存在w1.grad和w2.grad中。

在迭代w1和w2后,即使用完w1.grad和w2.grad后,使用zero_函数清空梯度。  

nn

model = torch.nn.Sequential( torch.nn.Linear(D_in, H), torch.nn.ReLU(), torch.nn.Linear(H, D_out),)loss_fn = torch.nn.MSELoss(reduction='sum')learning_rate = 1e-4for t in range(500): y_pred = model(x) loss = loss_fn(y_pred, y) if t % 100 == 99: print(t, loss.item()) model.zero_grad() loss.backward() with torch.no_grad(): for param in model.parameters(): param -= learning_rate * param.grad

optim

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)for t in range(500): y_pred = model(x) loss = loss_fn(y_pred, y) if t % 100 == 99: print(t, loss.item()) optimizer.zero_grad() loss.backward() optimizer.step()四、训练

迭代进行训练以及测试,其中训练的函数train里就保存了进行梯度下降求解的方法

# 定义训练函数,需要def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 从数据加载器中读取batch(一次读取多少张,即批次数),X(图片数据),y(图片真实标签)。 for batch, (X, y) in enumerate(dataloader): # 将数据存到显卡 X, y = X.to(device), y.to(device) # 得到预测的结果pred pred = model(X) # 计算预测的误差 # print(pred,y) loss = loss_fn(pred, y) # 反向传播,更新模型参数 optimizer.zero_grad() #梯度清零 loss.backward() #反向传播 optimizer.step() #更新参数 # 每训练10次,输出一次当前信息 if batch % 10 == 0: loss, current = loss.item(), batch * len(X) print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

设置为测试模型并设置不计算梯度,进行测试数据集的加载,判断预测值与实际标签是否一致,统一正确信息个数

# 将模型转为验证模式model.eval()# 测试时模型参数不用更新,所以no_gard()with torch.no_grad(): # 加载数据加载器,得到里面的X(图片数据)和y(真实标签) for X, y in dataloader: 加载数据 pred = model(X)#进行预测 # 预测值pred和真实值y的对比 test_loss += loss_fn(pred, y).item() # 统计预测正确的个数 correct += (pred.argmax(1) == y).type(torch.float).sum().item()#返回相应维度的最大值的索引test_loss /= sizecorrect /= sizeprint(f"correct = {correct}, Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")附录

mark一下很有用的博客

pytorch代码编写入门 - 简书

推荐给大家!Pytorch编写代码基本步骤思想 - 知乎

用pytorch实现神经网络_徽先生的博客-CSDN博客_pytorch 神经网络

Dataset、DataLoader

① 创建一个 Dataset 对象 ② 创建一个 DataLoader 对象 ③ 循环这个 DataLoader 对象,将xx, xx加载到模型中进行训练

pytorch从零开始搭建神经网络(pytorch新手入门)

DataLoader详解_sereasuesue的博客-CSDN博客_dataloader

都会|可能会_深入浅出 Dataset 与 DataLoader

Pytorch加载自己的数据集(使用DataLoader读取Dataset)_l8947943的博客-CSDN博客_pytorch dataloader读取数据

可以直接调用的数据集

https://www.pianshen.com/article/9695297328/

nn.Sequential

pytorch教程之nn.Sequential类详解——使用Sequential类来自定义顺序连接模型_LoveMIss-Y的博客-CSDN博客_sequential类

nn.Module

torch.nn.Module是torch.nn.functional中方法的实例化

pytorch教程之nn.Module类详解——使用Module类来自定义模型_LoveMIss-Y的博客-CSDN博客_torch.nn.module

对应Sequential的三种包装方式,Module有三种写法

model.train() 和 model.eval() model.train()for epoch in range(epoch): for train_batch in train_loader: ... zhibiao = test(epoch, test_loader, model)def test(epoch, test_loader, model): model.eval() for test_batch in test_loader: ... return zhibiao

【Pytorch】model.train() 和 model.eval() 原理与用法_想变厉害的大白菜的博客-CSDN博客_pytorch train()

pytroch:model.train()、model.eval()的使用_像风一样自由的小周的博客-CSDN博客_model.train()放在程序的哪个位置

model = ...dataset = ...loss_fun = ...# traininglr=0.001model.train()for x,y in dataset: model.zero_grad() p = model(x) l = loss_fun(p, y) l.backward() for p in model.parameters(): p.data -= lr*p.grad# evaluatingsum_loss = 0.0model.eval()with torch.no_grad(): for x,y in dataset: p = model(x) l = loss_fun(p, y) sum_loss += lprint('total loss:', sum_loss)

https://www.jb51.net/article/211954.htm

损失

MAE:

import torchfrom torch.autograd import Variablex = Variable(torch.randn(100, 100))y = Variable(torch.randn(100, 100))loos_f = torch.nn.L1Loss()loss = loos_f(x,y)

MSE:

import torchfrom torch.autograd import Variablex = Variable(torch.randn(100, 100))y = Variable(torch.randn(100, 100))loos_f = torch.nn.MSELoss()#定义loss = loos_f(x, y)#调用

torch.nn中常用的损失函数及使用方法_加油上学人的博客-CSDN博客_nn损失函数

优化器

pytorch 优化器调参以及正确用法 - 简书

训练&测试

基于pytorch框架下的一个简单的train与test代码_黎明静悄悄啊的博客-CSDN博客

图神经网络

1. GCN、GAT

图神经网络及其Pytorch实现_jiangchao98的博客-CSDN博客_pytorch 图神经网络

2. 用DGL

PyTorch实现简单的图神经网络_梦家的博客-CSDN博客_pytorch图神经网络

一文看懂图神经网络GNN,及其在PyTorch框架下的实现(附原理+代码) - 知乎

图神经网络的不足

•扩展性差,因为训练时需要用到包含所有节点的邻接矩阵,是直推性的(transductive)

•局限于浅层,图神经网络只有两层

•不能作用于有向图

3. 用PyG

图神经网络框架-PyTorch Geometric(PyG)的使用__Old_Summer的博客-CSDN博客_pytorch-geometric

本文链接地址:https://www.jiuchutong.com/zhishi/300852.html 转载请保留说明!

上一篇:自动驾驶入门必须要学会的ADAS(详解)(自动驾驶科普)

下一篇:基于OC端的Bridge-API组件化应用(oc底层原理)

  • word怎么批量替换文字(word怎么批量替换不同内容)

    word怎么批量替换文字(word怎么批量替换不同内容)

  • 快手运费险怎么开通(快手运费险怎么越来越高)

    快手运费险怎么开通(快手运费险怎么越来越高)

  • vivox27淘宝深色模式怎么设置(手机淘宝有没有深色模式)

    vivox27淘宝深色模式怎么设置(手机淘宝有没有深色模式)

  • 任天堂日版和国行区别

    任天堂日版和国行区别

  • 微信文件助手里的文件怎么打印出来(微信文件助手里的图片如何全部拖出来)

    微信文件助手里的文件怎么打印出来(微信文件助手里的图片如何全部拖出来)

  • 快手有消息怎么不显示(快手有消息怎么没有提示音)

    快手有消息怎么不显示(快手有消息怎么没有提示音)

  • auto相机上的什么意思(相机auto好用吗)

    auto相机上的什么意思(相机auto好用吗)

  • 三星手机黑屏但有震动(三星手机黑屏但是触摸正常)

    三星手机黑屏但有震动(三星手机黑屏但是触摸正常)

  • miui开发版切换为稳定版会清除数据吗(miui开发版切换稳定版会清除数据吗)

    miui开发版切换为稳定版会清除数据吗(miui开发版切换稳定版会清除数据吗)

  • Word段后一行如何设置(word段后间距一行)

    Word段后一行如何设置(word段后间距一行)

  • 荣耀30和华为nova7有什么区别(荣耀30和华为nova5哪个好)

    荣耀30和华为nova7有什么区别(荣耀30和华为nova5哪个好)

  • 手机网页验证滑块不动(手机网页验证滑块拖不动)

    手机网页验证滑块不动(手机网页验证滑块拖不动)

  • 美拍登录不了怎么办(美拍登录不了怎么注销以前的账号)

    美拍登录不了怎么办(美拍登录不了怎么注销以前的账号)

  • ac220v50hz是什么意思(ac220v60hz什么意思)

    ac220v50hz是什么意思(ac220v60hz什么意思)

  • 华为如何关闭徕卡(华为mate30可以关闭徕卡)

    华为如何关闭徕卡(华为mate30可以关闭徕卡)

  • 安卓微信号怎么修改(安卓微信号怎么申请第二个)

    安卓微信号怎么修改(安卓微信号怎么申请第二个)

  • 华为手机设置彩虹电池(华为手机设置彩铃免费)

    华为手机设置彩虹电池(华为手机设置彩铃免费)

  • 京东需要确认收货吗(京东需要确认收货才能申请售后吗)

    京东需要确认收货吗(京东需要确认收货才能申请售后吗)

  • 1050ti配什么cpu(1050ti配什么cpu能打永劫)

    1050ti配什么cpu(1050ti配什么cpu能打永劫)

  • 如何关闭微信流量提醒(如何关闭微信流量月包)

    如何关闭微信流量提醒(如何关闭微信流量月包)

  • 华为nova5耳机插孔在哪里(华为nova5耳机插哪里)

    华为nova5耳机插孔在哪里(华为nova5耳机插哪里)

  • 支付宝怎么用手机号登录(支付宝怎么用手机号收款)

    支付宝怎么用手机号登录(支付宝怎么用手机号收款)

  • ie浏览器如何更新升级(ie浏览器如何更改下载位置)

    ie浏览器如何更新升级(ie浏览器如何更改下载位置)

  • 抖音点赞受限怎么解决(抖音点赞受限怎么解除限制)

    抖音点赞受限怎么解决(抖音点赞受限怎么解除限制)

  • vivox20泡到水怎么办(vivo手机被水泡了怎么办)

    vivox20泡到水怎么办(vivo手机被水泡了怎么办)

  • 5g电话是什么意思(5g是什么意思指的是手机还是手机卡)

    5g电话是什么意思(5g是什么意思指的是手机还是手机卡)

  • 怎么卸载魔秀主题(怎么卸载魔秀主页的软件)

    怎么卸载魔秀主题(怎么卸载魔秀主页的软件)

  • 织梦模板DEDECMS附加表自定义字段关联主表文章(织梦模板转讯睿模板)

    织梦模板DEDECMS附加表自定义字段关联主表文章(织梦模板转讯睿模板)

  • 非货币性资产交换补价大于25%的会计处理
  • 税务登记管理办法2023
  • 固定资产可以一次性摊销吗
  • 契税为什么计入成本
  • 红字发票和蓝字一样吗
  • 税务登记法人变更后多久生效
  • 二手车交易怎么办理过户手续
  • 库存盘盈盘亏按进价还是售价
  • 已认证的进项税额转出如何操作
  • 出现销项负数
  • 应收账款资产减值损失转回和核销的区别
  • 雇主责任险发票的项目名称怎么写
  • 小规模纳税人代理记账一年费用
  • 新个税过了申报期怎么办
  • 人工成本全额扣除吗
  • 企业收到政府扶贫资金补助及运用补助金怎么做账
  • 筹建期的工资
  • win10教育版用户账户控制怎么取消
  • 电脑下载的文件打不开怎么回事
  • 小规模纳税企业在应交增值税明细科目
  • 行政划拨无偿取得的土地使用权属于什么资产
  • 不是第三方的贷款app
  • 电脑很空但是占用率90
  • 计算土地增值税时增值额的扣除项目包括
  • 采购原材料合理化建议
  • PHP:mcrypt_enc_get_key_size()的用法_Mcrypt函数
  • 下脚料属于什么科目
  • 残疾人就业保障金
  • 华盛顿州帕卢斯心雕塑
  • 货物运输业增值税专用发票
  • 企业所得税征收方式有哪些?
  • 专家评审费可以由中标人支付吗
  • 巴塞罗那城市布局
  • 企业以付费的形式
  • 小规模纳税人在什么情况下会成为一般纳税人
  • 超过五年的未弥补亏损如何处理?
  • java线程的执行体
  • css设置旋转动画
  • mongodb连接数
  • 如何进行会计制度改革
  • 各种账簿的登记依据和登记方法分别是什么
  • 折旧啥意思
  • 其他应付款如何平账
  • 单位内部食堂怎么收费
  • 印花税每个月都计提吗
  • 不能抵扣的普通发票如何做分录
  • 个人银行存款要手续费吗
  • 影视公司临时演员怎么办
  • 相关损坏维修成本是什么
  • 直接减免税款的例子
  • 金蝶k3怎么新增会计科目
  • 超市的商品品种繁多琳琅满目
  • 税务开票系统如何设置不用重复登录
  • 金税盘买发票还要填交验旧表吗?
  • 新开企业去银行开户需要什么
  • 商品流通企业如何控成本
  • 为什么我们需要政府
  • insert into tbl() select * from tb2中加入多个条件
  • 查看mysql执行计划关键字
  • MacBook怎么恢复出厂设置
  • dos命令 新建文件
  • win10预览版21h2
  • apache1.3.19配置文件
  • ctl.start
  • centos8 systemd
  • win10更新后安装包会自动删除吗
  • centos7.5安装桌面
  • 文件夹底部显示
  • Win10 Mobile 10586正式版即将向Insider用户推送
  • linux中使用find命令查找文件
  • perl调用perl脚本
  • 简述jQuery ajax的执行顺序
  • 从零开始学什么
  • linux的gunzip命令
  • 命令最常用的类型有
  • 安卓手机管家
  • 增值税纳税申报表附列资料(一)
  • 青海国家税务局总局官网
  • 全资子公司和全资子企业的区别
  • 关于铁路安全的漫画
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设