位置: IT常识 - 正文

pytorch初学笔记(六):DataLoader的使用(pytorch入门教程(非常详细))

编辑:rootadmin
pytorch初学笔记(六):DataLoader的使用

目录

一、DataLoader介绍

1. DataLoader作用

2. 常用参数介绍 

二、DataLoader的使用

1. 导入并实例化DataLoader

2. 具体使用

2.1 数据集中数据的读取

2.2 DataLoader中数据的读取

3. 使用tensorboard可视化效果

3.1 改变batchsize 

3.2 改变drop_last

3.3 改变shuffle


一、DataLoader介绍1. DataLoader作用

推荐整理分享pytorch初学笔记(六):DataLoader的使用(pytorch入门教程(非常详细)),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch零基础,pytorch教程,pytorch零基础入门,pytorch 入门,pytorch 快速入门,pytorch零基础入门,pytorch零基础入门,pytorch 入门教程,内容如对您有帮助,希望把文章链接给更多的朋友!

        DataLoader是一个可迭代的数据装载器,组合了数据集和采样器,并在给定数据集上提供可迭代对象。可以完成对数据集中多个对象的集成。

2. 常用参数介绍 

torch.utils.data — PyTorch 1.13 documentation

CLASS  DataLoader

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, 

batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, 

drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, 

generator=None, *, prefetch_factor=2, persistent_workers=False, pin_memory_device='') 

先导概念介绍:

Epoch: 所有训练样本都已输入到模型中,称为一个epochIteration: 一批样本(batch_size)输入到模型中,称为一个Iteration,Batchsize: 一批样本的大小, 决定一个epoch有多少个Iteration

常用的主要有以下五个参数:

dataset(数据集):需要提取数据的数据集,Dataset对象batch_size(批大小):每一次装载样本的个数,int型 shuffle(洗牌):进行新一轮epoch时是否要重新洗牌,Boolean型num_workers:是否多进程读取机制drop_last:当样本数不能被batchsize整除时, 是否舍弃最后一批数据

二、DataLoader的使用

        我们使用CIFAR10的测试数据集来完成DataLoader的使用。

1. 导入并实例化DataLoader

        创建一个dataloader,设置批大小为4,每一个epoch重新洗牌,不进行多进程读取机制,不舍弃不能被整除的批次。

#导入数据集的包import torchvision.datasets#导入dataloader的包from torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWriter#创建测试数据集test_dataset = torchvision.datasets.CIFAR10(root="./CIRFA10",train=False,transform=torchvision.transforms.ToTensor())#创建一个dataloader,设置批大小为4,每一个epoch重新洗牌,不进行多进程读取机制,不舍弃不能被整除的批次test_dataloader = DataLoader(dataset=test_dataset,batch_size=4,shuffle=True,num_workers=0,drop_last=False)2. 具体使用2.1 数据集中数据的读取

        由于数据集中的数据已经被我们转换成了tensor型,我们用dataset[0]输出第一张图片,使用shape属性输出tensor类型的大小,target代表图片的标签。 

img,target = test_dataset[0]print(img.shape,target)

        可以看到图片有RGB3个通道,大小为32*32,target为3。

2.2 DataLoader中数据的读取

        在dataset中,每一个对象元组由一张图片对象img和一个标签target组成;

        而dataloader中会分别对一个批次中的图片和标签进行打包,因此dataloader中,每一个对象由元组由batchsize张图片对象imgs和batchsize个标签targets组成。

对一个批次中的所有图片对象进行打包,形成一个对象,我们叫它imgs对一个批次中所有的标签进行打包,形成一个对象,我们叫它targets

        我们需要通过for循环来取出loader中的对象,loader中的对象个数=数据集中对象个数/batch_size,故应为10000/4=2500个对象。

        核心代码:

for data in test_dataloader: imgs,targets = data print(imgs.shape) print(targets) #导入数据集的包import torchvision.datasets#导入dataloader的包from torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWriter#创建测试数据集test_dataset = torchvision.datasets.CIFAR10(root="./CIRFA10",train=False,transform=torchvision.transforms.ToTensor())#创建一个dataloader,设置批大小为4,每一个epoch重新洗牌,不进行多进程读取机制,不舍弃不能被整除的批次test_dataloader = DataLoader(dataset=test_dataset,batch_size=4,shuffle=True,num_workers=0,drop_last=False)#测试数据集中第一张图片对象img,target = test_dataset[0]print(img.shape,target)#打印数据集中图片数量print(len(test_dataset))#loader中对象for data in test_dataloader: imgs,targets = data print(imgs.shape) print(targets)#dataloader中对象个数print(len(test_dataloader))

        loader中的对象格式:

imgs的维度变成了4*3*32*32,即四张图片,每张图片3个通道,每张图片大小为32*32。targets里有4个target,分别是四张图片的target。pytorch初学笔记(六):DataLoader的使用(pytorch入门教程(非常详细))

       loader中的对象个数:

        2500个,数据集中图片个数为10000,10000/4=2500,验证正确。说明loader中数据按4个一组打包。 

3. 使用tensorboard可视化效果3.1 改变batchsize 

        修改数据集的batchsize为64,writer中调用的方法为add_images(),因为需要读取的图片有多张。

#导入数据集的包import torchvision.datasets#导入dataloader的包from torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWriter#创建测试数据集test_dataset = torchvision.datasets.CIFAR10(root="./CIRFA10",train=False,transform=torchvision.transforms.ToTensor())#创建一个dataloader,设置批大小为64,每一个epoch重新洗牌,不进行多进程读取机制,不舍弃不能被整除的批次test_dataloader = DataLoader(dataset=test_dataset,batch_size=64,shuffle=True,num_workers=0,drop_last=False)writer = SummaryWriter("log")#loader中对象step = 0for data in test_dataloader: imgs,targets = data writer.add_images("loader",imgs,step) step+=1writer.close()

结果如下所示,可以看到一个step中有64张图片。

        但是我们发现step156时只取了16张图片,是因为10000张图片每次取64张是不能整除的,因此最后剩下了16张,单独放在最后一个step中,对最后剩余数量的图片进行保留是因为我们设置的drop_last=False。

 

3.2 改变drop_last

        如果我们改变drop_last=True,则不会保留最后的16张图片,会被舍弃,只保留能被整除的批次。

 

        结果如下所示,可以看到最后一步为155步,没了最后的16张图片,只保留了所有能整除的64的step。 

3.3 改变shuffle

        每一轮epoch之后就是分配完了一次数据,而shuffle决定了是否在新一轮epoch开始时打乱所有图片的属性进行分配。

        在代码中epoch就是最外层的循环,假设我们的epoch=2,即需要分配两次数据:

shuffle=TRUE代表第一轮循环结束后会打乱数据集中所有图片的顺序重新进行分配。shuffle=FALSE代表第一轮循环结束后不打乱数据集中所有图片的顺序,还是按原顺序进行分配。

3.3.1 shuffle=False时

#导入数据集的包import torchvision.datasets#导入dataloader的包from torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWriter#创建测试数据集test_dataset = torchvision.datasets.CIFAR10(root="./CIRFA10",train=False,transform=torchvision.transforms.ToTensor())#创建一个dataloader,设置批大小为64,每一个epoch重新洗牌,不进行多进程读取机制,不舍弃不能被整除的批次test_dataloader = DataLoader(dataset=test_dataset,batch_size=64,shuffle=False,num_workers=0,drop_last=True)writer = SummaryWriter("log")#loader中对象for epoch in range(2): step = 0 for data in test_dataloader: imgs, targets = data writer.add_images("Epoch:{}".format(epoch), imgs, step) step += 1writer.close()

        可以看到epoch=0和epoch=1的每一个step中的图片都是分配的相同的,说明每一轮大循环开始前没有在数据集中重新打乱顺序。

3.3.2 shuffle=True时

       可以看到epoch=0和epoch=1的每一个step中的图片不同了,说明每一轮大循环开始前都在数据集中重新打乱了顺序。

参考资料 

系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)_翻滚的小@强的博客-CSDN博客_dataloader读取顺序

DataLoader的使用_哔哩哔哩_bilibili 

本文链接地址:https://www.jiuchutong.com/zhishi/289781.html 转载请保留说明!

上一篇:如何自己搭建一个ai画图系统? 从0开始云服务器部署novelai(如何自己搭建一个邮箱服务器)

下一篇:【UML】-- 顺序图练习题含答案(自动售货机、学生选课、提款机、购买地铁票、洗衣机工作)(uml中的顺序图由什么组成)

  • 小米蓝牙音箱如何配对(小米蓝牙音箱如何连接)

    小米蓝牙音箱如何配对(小米蓝牙音箱如何连接)

  • 有线carplay怎么连接(有线carplay怎么投屏)

    有线carplay怎么连接(有线carplay怎么投屏)

  • 荣耀30pro是几个颜色(荣耀30pro几个送话器)

    荣耀30pro是几个颜色(荣耀30pro几个送话器)

  • 红米k30屏幕120hz可以调吗(红米k30屏幕亮度多少尼特)

    红米k30屏幕120hz可以调吗(红米k30屏幕亮度多少尼特)

  • 怎样屏蔽微信群(怎样隐藏微信群聊)

    怎样屏蔽微信群(怎样隐藏微信群聊)

  • 华为手环4pro充电没反应(华为手环4PRO充电座和B19一样吗)

    华为手环4pro充电没反应(华为手环4PRO充电座和B19一样吗)

  • 西瓜视频直播抖音能看到吗(西瓜视频直播抖音看得到吗)

    西瓜视频直播抖音能看到吗(西瓜视频直播抖音看得到吗)

  • 3dmax和cad的区别(3dmax与cad关系)

    3dmax和cad的区别(3dmax与cad关系)

  • sim卡故障能修复吗(sim卡出故障是不是卡坏了)

    sim卡故障能修复吗(sim卡出故障是不是卡坏了)

  • med一al00是什么型号手机(med al00)

    med一al00是什么型号手机(med al00)

  • ios13共享位置对方知道吗(iphone共享位置会出错吗)

    ios13共享位置对方知道吗(iphone共享位置会出错吗)

  • 歌手耳朵上戴的耳机有什么用(歌手耳朵上戴的而返能起到什么作用)

    歌手耳朵上戴的耳机有什么用(歌手耳朵上戴的而返能起到什么作用)

  • 蓝屏代码0x000008e(蓝屏代码0X000000F4)

    蓝屏代码0x000008e(蓝屏代码0X000000F4)

  • 淘宝朋友代付退货钱退在哪里(淘宝朋友代付退款到哪里去)

    淘宝朋友代付退货钱退在哪里(淘宝朋友代付退款到哪里去)

  • 卖天猫积分对淘宝号有影响不(天猫积分出售有收的吗)

    卖天猫积分对淘宝号有影响不(天猫积分出售有收的吗)

  • 快手视频C类是什么意思(快手视频有几种类型)

    快手视频C类是什么意思(快手视频有几种类型)

  • qq聊天记录怎么分享给别人(qq聊天记录怎么找回最早的记录)

    qq聊天记录怎么分享给别人(qq聊天记录怎么找回最早的记录)

  • eliza和siri是什么关系(siri跟eliza是什么关系)

    eliza和siri是什么关系(siri跟eliza是什么关系)

  • iphone11出厂有膜吗(苹果11买回来有没有膜)

    iphone11出厂有膜吗(苹果11买回来有没有膜)

  • 苹果11pro max怎么禁止屏幕自动旋转(苹果11pro max怎么录屏)

    苹果11pro max怎么禁止屏幕自动旋转(苹果11pro max怎么录屏)

  • 天猫logo的设计含义是什么(天猫logo设计说明)

    天猫logo的设计含义是什么(天猫logo设计说明)

  • 二手数据的特点(二手数据的特点是什么?什么情况下可以使用二手数据)

    二手数据的特点(二手数据的特点是什么?什么情况下可以使用二手数据)

  • oem application profile是什么(oem application profile)

    oem application profile是什么(oem application profile)

  • iphone7无限重启白苹果(iphone7无限重启黑苹果)

    iphone7无限重启白苹果(iphone7无限重启黑苹果)

  • 连不上网是什么原因(连不上网是什么原因出现感叹号)

    连不上网是什么原因(连不上网是什么原因出现感叹号)

  • 三维重建(知识点详细解读、主要流程)(三维重建是啥意思)

    三维重建(知识点详细解读、主要流程)(三维重建是啥意思)

  • 对外支付人民币存在残缺污损的问题
  • 机票和发票是一样的吗
  • 中小型企业营业额和从业人数
  • 小微企业公司章程范本
  • 经营性收入包括投资收益吗
  • 技术入股亏损如何清算
  • 摊销保险费会计分录怎么写
  • 收到购货单位货款属于什么会计科目
  • 融资租入固定资产
  • 外购产品赠送他人合法吗
  • 哪些公司可以开咨询费发票
  • 低值易耗品进项税额转出账务处理
  • 差额征税扣除额大于收入时如何开票?
  • 怎么看是不是专用发票
  • 非营利组织免税资格可以免些什么税
  • 研发费用的检测费指的是什么内容
  • 工资薪金总额包括哪些内容
  • 金银首饰的消费税在什么环节
  • 微信占用空间大是怎么回事
  • 如何增强无线网卡的接收能力
  • 奖金发放如何做账
  • 如何关闭开始菜单快捷键
  • 两免三减半条件
  • 二手车增值税专用发票税率
  • 对视同销售行为应如何进行税务处理
  • vue3+vite在main.ts或者main.js文件中引入/App.vue报错(/App.vue不是模块)
  • PHP:zip_entry_compressionmethod()的用法_Zip函数
  • 建筑安装发票可以外地开吗
  • 如何分清福利性劳动
  • php获取访问用户的ip
  • 在计算应纳税所得额时,不允许作为税金项目
  • 增值税加计抵减怎么算
  • 机关事业单位购买茶叶违反什么规定
  • 报销钱大写数字
  • 给员工报销
  • 退休返聘人员算临时工吗
  • 云原生是什么
  • 基于Pytorch的风格转换
  • 遍历队列中所有数据元素
  • python 动态
  • 帝国cms怎么用
  • 购入库存商品会计分录摘要
  • 个人独资为什么不能叫公司
  • 增值税申报表各栏怎么填
  • 企业的职工福利费应当按照应付工资总额的14%计提
  • 小规模发票跨月冲红怎么做账
  • cms访问出错
  • python 函数 global
  • 更正申报补缴税款会影响记录
  • 个体户开发票超过定额是如何交税?
  • 福利费属于管理费吗
  • 开票软件怎样
  • 分公司非独立核算怎么报税
  • 收缩数据库日志文件对数据有影响吗
  • 公司对公账户转给个人
  • 发票的种类有哪些?存在哪些区别
  • 运输公司赔偿账务处理
  • 扶贫小额信贷分贷统还违规吗
  • 登记账簿遇到的问题及解决
  • 应付账款周转率越大越好还是越小越好?
  • 工商营业执照变更网上怎么操作
  • 结转损益不平是什么原因造成的
  • 安装sql server需要注意什么
  • 微软官方的网址是多少
  • phpstudy中phpmyadmin无法访问
  • 如何安装vmware10
  • windows 8.1 build 9600
  • win7升级win10之后视频解码能力变弱
  • xp直升win7
  • mac无线打印
  • 黑客怎样入侵别人手机
  • 索尼笔记本安装软件顺序
  • 搭建android开发环境时为什么要先安装jdk
  • jquery刷新页面的方法
  • shell脚本中实现rm -fr !(file1)
  • jQuery Timelinr实现垂直水平时间轴插件(附源码下载)
  • python3整除
  • 为什么征收城市建设维护税却不征收教育附加税
  • 资源税的征税范围一般包括
  • 实木地板什么
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设