位置: IT常识 - 正文

Pytorch深度学习实战3-7:详解数据加载DataLoader与模型处理

编辑:rootadmin
原力计划Pytorch深度学习实战3-7:详解数据加载DataLoader与模型处理 目录1 数据集Dataset2 数据加载DataLoader3 常用预处理方法4 模型处理5 实例:MNIST数据集处理1 数据集Dataset

推荐整理分享Pytorch深度学习实战3-7:详解数据加载DataLoader与模型处理,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

Dataset类是Pytorch中图像数据集操作的核心类,Pytorch中所有数据集加载类都继承自Dataset父类。当我们自定义数据集处理时,必须实现Dataset类中的三个接口:

初始化

def __init__(self)

构造函数,定义一些数据集的公有属性,如数据集下载地址、名称等数据集大小

def __len__(self)

返回数据集大小,不同的数据集有不同的衡量数据量的方式数据集索引

def __getitem__(self, index):

支持数据集索引功能,以实现形如dataset[i]得到数据集中的第i + 1个数据的功能。__getitem__是后期迭代数据时执行的具体函数,其返回值决定了循环变量,例如

class data(Dataset)...def __getitem__(self, idx: int):if self.transforms:img = self.transforms(img)return img, label# 返回的值即为后续迭代的循环变量for images, labels in dataLoader:...2 数据加载DataLoader

为什么有了数据集Dataset还需要数据加载器DataLoader呢?原因在于神经网络需要进一步借助DataLoader对数据进行划分,也就是我们常说的batch,此外DataLoader还实现了打乱数据集、多线程等操作。

DataLoader本质是一个可迭代对象,可以使用形如

for inputs, labels in dataloaders

进行可迭代对象的访问。

我们一般不需要去实现DataLoader的接口,只需要在构造函数中指定相应的参数即可,比如常见的batch_size,shuffle等参数。

Pytorch深度学习实战3-7:详解数据加载DataLoader与模型处理

下面这张图非常好地说明了Dataset和DataLoader的关系

接下来总结数据构造的三步法

继承Dataset对象,并实现__len__()、__getitem__()魔法方法,该步骤的主要目的在于将文件形式的数据集处理为模型可用的标准数据格式,并加载到内存中;用DataLoader对象封装Dataset,使其成为可迭代对象;遍历DataLoader对象以将数据加载到模型中进行训练。3 常用预处理方法

在数据集Dataset的__getitem__()中利用torchvision.transforms进行数据预处理与变换

常见的数据预处理变换方法总结如下表

序号变换含义1RandomCrop(size, ...)对输入图像依据给定size随机裁剪2CenterCrop(size, ...)对输入图像依据给定size从中心裁剪3RandomResizedCrop(size, ...)对输入图像随机长宽比裁剪,再放缩到给定size4FiveCrop(size, ...)对输入图像进行上下左右及中心裁剪,返回五张图像(size)组成的四维张量5TenCrop(size, vertical_flip=False)对输入图像进行上下左右及中心裁剪,再全部翻转(水平或垂直),返回十张图像(size)组成的四维张量6RandomHorizontalFlip(p=0.5)对输入图像按概率p随机进行水平翻转7RandomVerticalFlip(p=0.5)对输入图像按概率p随机进行垂直翻转8RandomRotation(degree, ...)对输入图像在degree内随机旋转某角度9Resize(size, ...)对输入图像重置分辨率10Normalize(mean, std)对输入图像各通道进行标准化11ToTensor()将输入图像或ndarray 转换为tensor并归一化12Pad(padding, fill=0, padding_mode=‘constant’)对输入图像进行填充13ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)对输入图像修改亮度、对比度、饱和度、色度等14Grayscale(num_output_channels=1)对输入图像转灰度15LinearTransformation(matrix)对输入图像进行线性变换16RandomAffine(...)对输入图像进行仿射变换17RandomGrayscale(p=0.1)对输入图像按概率p随机转灰度18ToPILImage(mode=None)对输入图像转PIL格式图像19RandomOrder()随机打乱transforms操作顺序4 模型处理

考虑以下场景:

网络的部分层级结构已经收敛、无需调整;大型复杂网络需要微调(Fine-tune)某些结构或参数;希望基于已训练好的模型进行改善或其他研究工作。

这些场景下重新通过数据集训练整个神经网络并无必要,甚至会使模型不稳定,因此引入预训练(pretrained)。Pytorch允许用户保存已训练好的模型,或加载其他模型,避免往复的无谓重训练,其中模型参数文件以.pth为后缀

# 保存已训练模型torch.save(model.state_dict(), path)# 加载预训练模型model.load_state_dict(torch.load(path), device)

通过设置模型某些层可学习参数的requires_grad属性为False即可固定这部分参数不被后续学习过程影响。深度学习框架应用优势之一在于预设了对GPU的支持,大大提高模型处理与训练的效率。Pytorch中通过mode.to(device)方法将模型部署到指定设备上(CPU/GPU),范式如下:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model.to(device)

工程上也常使用torch.nn.DataParallel(model, devices)来处理多GPU并行运算,其原理是:首先将模型加载到主GPU上,再将模型从主GPU产生若干副本到其余GPU,随后将一个batch中的数据按维度划分为不同的子任务给各GPU进行前向传播,得到的损失会被累积到主GPU上并由主GPU反向传播更新参数,最后将更新参数拷贝到其余GPU以开始下一轮训练。

5 实例:MNIST数据集处理

下面给出了处理MNIST手写数据集的完整代码,可以用于加深对数据处理流程的理解

from abc import abstractmethodimport numpy as npfrom torchvision.datasets import mnistfrom torch.utils.data import Datasetfrom PIL import Imageclass mnistData(Dataset): ''' * @breif: MNIST数据集抽象接口 * @param[in]: dataPath -> 数据集存放路径 * @param[in]: transforms -> 数据集变换 ''' def __init__(self, dataPath: str, transforms=None) -> None: super().__init__() self.dataPath = dataPath self.transforms = transforms self.data, self.label = [], [] def __len__(self) -> int: return len(self.label) def __getitem__(self, idx: int): img = self.data[idx] if self.transforms: img = self.transforms(img) return img, self.label[idx] @abstractmethod def plot(self, index: int) -> None: pass @abstractmethod def load(self) -> list: pass def plotData(self, index: int, info: str=None) -> None: ''' * @breif: 可视化训练数据 * @param[in]: index -> 数据集索引 * @param[in]: info -> 备注信息 * @retval: None ''' print(info, " --index:", index, "--label:", self.label[index]) if info else \ print(" --index:", index, "--label:", self.label[index]) img = Image.fromarray(np.uint8(self.data[index])) img.show() def loadData(self, train: bool) -> list: ''' * @breif: 下载与加载数据集 * @param[in]: train -> 是否为训练集 * @retval: 数据与标签列表 ''' # 如果指定目录下不存在数据集则下载 dataSet = mnist.MNIST(self.dataPath, train=train, download=True) # 初始化数据与标签 data = [ i[0] for i in dataSet ] label = [ i[1] for i in dataSet ] return data, labelclass mnistTrainData(mnistData): ''' * @breif: MNIST训练集 * @param[in]: dataPath -> 数据集存放路径 * @param[in]: transforms -> 数据集变换 ''' def __init__(self, dataPath: str, transforms=None) -> None: super().__init__(dataPath, transforms=transforms) self.data, self.label = self.load() def plot(self, index: int) -> None: self.plotData(index, "trainSet data") def load(self) -> list: return self.loadData(train=True)class mnistTestData(mnistData): ''' * @breif: MNIST测试集 * @param[in]: dataPath -> 数据集存放路径 * @param[in]: transforms -> 数据集变换 ''' def __init__(self, dataPath: str, transforms=None) -> None: super().__init__(dataPath, transforms=transforms) self.data, self.label = self.load() def plot(self, index: int) -> None: self.plotData(index, "testSet data") def load(self) -> list: return self.loadData(train=False)

本文链接地址:https://www.jiuchutong.com/zhishi/300789.html 转载请保留说明!

上一篇:React中使用Redux (二) - 通过react-redux库连接React和Redux(react redux reducer)

下一篇:多模态特征融合:图像、语音、文本如何转为特征向量并进行分类(多模态特征融合pytorch)

  • 苹果手机应用与数据在哪里(苹果手机应用与管理在哪里找到)

    苹果手机应用与数据在哪里(苹果手机应用与管理在哪里找到)

  • OPPO手机怎么设置儿童使用时间(oppo手机怎么设置锁屏密码)

    OPPO手机怎么设置儿童使用时间(oppo手机怎么设置锁屏密码)

  • 小米10pro支持无线充电的吗(小米10pro支持无线充电功能吗)

    小米10pro支持无线充电的吗(小米10pro支持无线充电功能吗)

  • 苹果怎么关掉来电闪光灯(苹果怎么关掉来电声音变小)

    苹果怎么关掉来电闪光灯(苹果怎么关掉来电声音变小)

  • p30和p30pro摄像对比(p30和p30pro相机)

    p30和p30pro摄像对比(p30和p30pro相机)

  • word表格怎么调整位置(word表格怎么调整某一个单元格)

    word表格怎么调整位置(word表格怎么调整某一个单元格)

  • 快手浏览别人作品别人有记录吗(快手浏览别人作品怎么不留痕迹)

    快手浏览别人作品别人有记录吗(快手浏览别人作品怎么不留痕迹)

  • 苹果xr连接wifi不稳定(苹果xr连接wifi不显示图标)

    苹果xr连接wifi不稳定(苹果xr连接wifi不显示图标)

  • .msi是什么文件(msi格式的文件)

    .msi是什么文件(msi格式的文件)

  • 开启扬声器是什么意思(扬声器从哪里开)

    开启扬声器是什么意思(扬声器从哪里开)

  • 安全模式按f几(开机进安全模式按f几)

    安全模式按f几(开机进安全模式按f几)

  • 手机飞行模式还能查到轨迹吗(手机飞行模式还能被追踪到吗)

    手机飞行模式还能查到轨迹吗(手机飞行模式还能被追踪到吗)

  • 荣耀30有防抖吗(荣耀30有防抖功能吗)

    荣耀30有防抖吗(荣耀30有防抖功能吗)

  • 7和7plus有什么区别(7和7p哪个性价比高)

    7和7plus有什么区别(7和7p哪个性价比高)

  • ipad电池百分1不停重启(ipad2020电池容量不是百分百)

    ipad电池百分1不停重启(ipad2020电池容量不是百分百)

  • vivox9可不可以用5g网(vivox9能用5g卡吗)

    vivox9可不可以用5g网(vivox9能用5g卡吗)

  • 处理器个数更改有何用(处理器修改)

    处理器个数更改有何用(处理器修改)

  • 号码拉黑了打电话有提示吗(号码拉黑了打电话过去还有信息不)

    号码拉黑了打电话有提示吗(号码拉黑了打电话过去还有信息不)

  • word打竖行文字居中(word打竖排字)

    word打竖行文字居中(word打竖排字)

  • 华为mate30怎么设置指纹应用锁(华为mate30怎么设置下面三个键)

    华为mate30怎么设置指纹应用锁(华为mate30怎么设置下面三个键)

  • 手机后台运行怎么打开(苹果手机14怎么关闭后应用运行)

    手机后台运行怎么打开(苹果手机14怎么关闭后应用运行)

  • 苹果手机怎么截长图(苹果手机怎么截图全屏长图)

    苹果手机怎么截长图(苹果手机怎么截图全屏长图)

  • 联华充值卡微信能用吗(联华充值卡怎么充值到微信账户)

    联华充值卡微信能用吗(联华充值卡怎么充值到微信账户)

  • 荣耀play怎么反向充电(荣耀play4怎样返回)

    荣耀play怎么反向充电(荣耀play4怎样返回)

  • xr如何快速切换主副卡(xr怎么开启切换控制)

    xr如何快速切换主副卡(xr怎么开启切换控制)

  • 探探卸载了别人还能看到我吗(探探卸载了别人还能看到距离吗)

    探探卸载了别人还能看到我吗(探探卸载了别人还能看到距离吗)

  • 抖音怎样才能开直播(抖音怎样才能开直播卖货)

    抖音怎样才能开直播(抖音怎样才能开直播卖货)

  • DRWTSN16.EXE是病毒程序吗 DRWTSN16进程是不是病毒(winspool.drv病毒)

    DRWTSN16.EXE是病毒程序吗 DRWTSN16进程是不是病毒(winspool.drv病毒)

  • Vue3 京东到家项目实战第一篇(首页及登录功能开发) 进阶式掌握vue3完整知识体系(京东到家的物流模式)

    Vue3 京东到家项目实战第一篇(首页及登录功能开发) 进阶式掌握vue3完整知识体系(京东到家的物流模式)

  • 公司税后利润怎么算
  • 购车增值税可以抵扣多少
  • 进项抵扣联丢了怎么办
  • 商业用房出租税率是多少
  • 做网站的费用会计分录
  • 房地产行业企业所得税政策
  • 人工费能不能抵扣进项税
  • 普通增值税发票税号
  • 发票代签怎么处理
  • 软件开发成本核算模板
  • 销售自己使用过的物品
  • 营业外收入在资产负债表怎么填
  • 社保利息是什么意思
  • 个人独资企业缴纳个人所得税
  • 税负的含义
  • 备品的定义
  • 专家评审费如何报销费用
  • 营改增后哪些费用可以抵扣
  • 三证合一之前
  • 工资表中有哪些项目
  • 税款滞纳金怎么入账
  • 发票记账联丢失怎么写情况说明
  • 反写状态已反写是什么意思
  • 管理费用劳务费现金流
  • 招标文件中资金性质应填什么
  • 广告费和业务宣传费税前扣除基数
  • 个体户流水过大怎么避税
  • 长期投资损失的确认
  • dghm.exe是什么程序
  • 期间损益结转有余额
  • 交易性金融资产的交易费用计入哪里
  • win11预览版dev改beta
  • 建材网上销售平台有哪些
  • 卡特迈国家公园地图
  • 印花税贴花怎么贴划线
  • 用java写一个helloworld
  • 要点初见:Stable Diffusion NovelAI模型优质文字Tag汇总与实践【魔咒汇总】
  • thinkphp d
  • php中获取当前时间
  • 其他非流动资产包括哪些
  • window.eval方法
  • php中表单的使用
  • javascript数据类型分为哪两类
  • 土地开发是什么
  • 社保基数怎么申请下调
  • sql数据库移动
  • 织梦标签教程
  • 小规模纳税人现代服务税率
  • 个人独资企业应税生产经营所得可以扣除税金支付
  • 契税是指什么?
  • 限售股算不算账户资产
  • 税控技术服务费会计分录
  • 购买员工团体意外险需要缴纳个税么
  • 购买的税控设备
  • 可供分配的利润分配顺序
  • 零余额账户银行回单
  • 消费税什么时候计入成本
  • 利润表的未分配利润是哪个
  • 一般纳税人的进项税额可以抵扣吗
  • 利息支出和利息收入区别
  • 减免增值税如何申报
  • 出租人融资租赁的判断标准
  • sql语句编译执行过程
  • 注册表隐藏桌面图标
  • linux系统硬盘分区
  • ubuntu20.04配置
  • macbook 苹果系统
  • html5能做游戏吗
  • ExtJS扩展 垂直tabLayout实现代码
  • windows运行bat文件命令
  • python教程详细
  • Unity3D游戏开发培训课程大纲
  • vue+node+webpack环境搭建教程
  • JavaScript中Number.NEGATIVE_INFINITY值的使用详解
  • android底部弹出页面
  • javascript for in
  • android 自定义actionbar
  • 江苏国税电子税局
  • 装卸搬运费属于
  • 无锡市国税
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设