位置: IT常识 - 正文

【深度学习】Pytorch实现CIFAR10图像分类任务测试集准确率达95%

编辑:rootadmin
【深度学习】Pytorch实现CIFAR10图像分类任务测试集准确率达95% 文章目录前言CIFAR10简介Backbone选择训练+测试训练环境及超参设置完整代码部分测试结果完整工程文件Reference前言

推荐整理分享【深度学习】Pytorch实现CIFAR10图像分类任务测试集准确率达95%,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

分享一下本人去年入门深度学习时,在CIFAR10数据集上做的图像分类任务,使用了多个主流的backbone网络,希望可以为同样想入门深度学习的同志们,提供一个方便上手、容易理解的参考教程。

CIFAR10简介

CIFAR-10数据集是图像分类领域经典的数据集,由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理得到,一共包含10个类别的 RGB彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck ),图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。 CIFAR-10 的图片样例如图所示   Pytorch中提供了如下命令可以直接将CIFAR10数据集下载到本地:

import torchvisiondataset = torchvision.datasets.CIFAR10(root, train=True, download=True, transform)root:数据集加载到本地的路径train=True:True表示加载训练集,False加载测试集download=True:True表示加载数据集到root,若数据集已经存在,则不会再加载transform:数据增强

  这里分享一个加载CIFAR10数据集的完整代码:

# 设置数据增强print('==> Preparing data..')transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])# 加载CIFAR10数据集trainset = torchvision.datasets.CIFAR10( root=opt.data, train=True, download=True, transform=transform_train)trainloader = torch.utils.data.DataLoader( trainset, batch_size=opt.batch_size, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10( root=opt.data, train=False, download=True, transform=transform_test)testloader = torch.utils.data.DataLoader( testset, batch_size=100, shuffle=False, num_workers=2)Backbone选择

本文主要尝试了以下几个主流的backbone网络,并在CIFAR10上实现了图像分类任务:

LetNetAlexNetVGGGoogLeNet(InceptionNet)ResNetDenseNetResNeXtSENetMobileNetv2-v3ShuffleNetv2EfficientNetB0Darknet53CSPDarknet53【深度学习】Pytorch实现CIFAR10图像分类任务测试集准确率达95%

  这里放上测试结果最好的ResNet模块的构建代码,其他代码放到最后完整工程backbone文件夹中:

"""pytorch实现ResNet50、ResNet101和ResNet152:"""import torchimport torch.nn as nnimport torchvisionimport torch.nn.functional as F# conv1 7 x 7 64 stride=2def Conv1(channel_in, channel_out, stride=2): return nn.Sequential( nn.Conv2d( channel_in, channel_out, kernel_size=7, stride=stride, padding=3, bias=False ), nn.BatchNorm2d(channel_out), # 会改变输入数据的值 # 节省反复申请与释放内存的空间与时间 # 只是将原来的地址传递,效率更好 nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=stride, padding=1) )# 构建ResNet18-34的网络基础模块class BasicBlock(nn.Module): expansion = 1 def __init__(self, in_planes, planes, stride=1): super(BasicBlock, self).__init__() self.conv1 = nn.Conv2d( in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.shortcut = nn.Sequential() if stride != 1 or in_planes != self.expansion * planes: self.shortcut = nn.Sequential( nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(self.expansion * planes) ) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out += self.shortcut(x) out = F.relu(out) return out# 构建ResNet50-101-152的网络基础模块class Bottleneck(nn.Module): expansion = 4 def __init__(self, in_planes, planes, stride=1): super(Bottleneck, self).__init__() # 构建 1x1, 3x3, 1x1的核心卷积块 self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(self.expansion * planes) # 采用1x1的kernel,构建shout cut # 注意这里除了第一个bottleblock之外,都需要下采样,所以步长要设置为stride=2 self.shortcut = nn.Sequential() if stride != 1 or in_planes != self.expansion * planes: self.shortcut = nn.Sequential( nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(self.expansion * planes) ) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = F.relu(self.bn2(self.conv2(out))) out = self.bn3(self.conv3(out)) out += self.shortcut(x) out = F.relu(out) return out# 搭建ResNet模板块class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes=10): super(ResNet, self).__init__() self.in_planes = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) # 逐层搭建ResNet self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) self.linear = nn.Linear(512 * block.expansion, num_classes) # 参数初始化 # for m in self.modules(): # if isinstance(m, nn.Conv2d): # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') # elif isinstance(m, nn.BatchNorm2d): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) def _make_layer(self, block, planes, num_blocks, stride): strides = [stride] + [1] * (num_blocks - 1) # layers = [ ] 是一个列表 # 通过下面的for循环遍历配置列表,可以得到一个由 卷积操作、池化操作等 组成的一个列表layers # return nn.Sequential(*layers),即通过nn.Sequential函数将列表通过非关键字参数的形式传入(列表layers前有一个星号) layers = [] for stride in strides: layers.append(block(self.in_planes, planes, stride)) self.in_planes = planes * block.expansion return nn.Sequential(*layers) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = F.avg_pool2d(out, 4) out = out.view(out.size(0), -1) out = self.linear(out) return outdef ResNet18(): return ResNet(BasicBlock, [2, 2, 2, 2])def ResNet34(): return ResNet(BasicBlock, [3, 4, 6, 3])def ResNet50(): return ResNet(Bottleneck, [3, 4, 6, 3])def ResNet101(): return ResNet(Bottleneck, [3, 4, 23, 3])def ResNet152(): return ResNet(Bottleneck, [3, 8, 36, 3])# 测试# if __name__ == '__main__':# model = ResNet50()# print(model)## input = torch.randn(1, 3, 32, 32)# out = model(input)# print(out.shape)训练+测试训练环境及超参设置

本文的训练环境和超参数设置如下:

1块1080 Ti GPUepoch为100batch-size为128优化器:SGD学习率:余弦退火有序调整学习率

  主要步骤如下:

加载数据集

将数据集加载到本地按batch-size加载到dataLoader设置相关参数

指定GPU训练相关参数断点续训模型保存参数设置优化器设置学习率循环每个epoch

开启训练开启测试学习率调整数据可视化打印结果完整代码'''Train CIFAR10 with PyTorch.'''import torchvision.transforms as transformsimport timeimport torchimport torchvisionimport torch.nn as nnimport torch.optim as optimimport torch.backends.cudnn as cudnnfrom torch.utils.data import DataLoaderimport matplotlib.pyplot as pltimport osimport argparse# 导入模型from backbones.ResNet import ResNet18# 指定GPUos.environ['CUDA_VISIBLE_DEVICES'] = '1'# 用于计算GPU运行时间def time_sync(): # pytorch-accurate time if torch.cuda.is_available(): torch.cuda.synchronize() return time.time()# Trainingdef train(epoch): model.train() train_loss = 0 correct = 0 total = 0 train_acc = 0 # 开始迭代每个batch中的数据 for batch_idx, (inputs, targets) in enumerate(trainloader): # inputs:[b,3,32,32], targets:[b] # train_outputs:[b,10] inputs, targets = inputs.to(device), targets.to(device) # print(inputs.shape) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() # 计算损失 train_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() # 计算准确率 train_acc = correct / total # 每训练100个batch打印一次训练集的loss和准确率 if (batch_idx + 1) % 100 == 0: print('[INFO] Epoch-{}-Batch-{}: Train: Loss-{:.4f}, Accuracy-{:.4f}'.format(epoch + 1, batch_idx + 1, loss.item(), train_acc)) # 计算每个epoch内训练集的acc total_train_acc.append(train_acc)# Testingdef test(epoch, ckpt): global best_acc model.eval() test_loss = 0 correct = 0 total = 0 test_acc = 0 with torch.no_grad(): for batch_idx, (inputs, targets) in enumerate(testloader): inputs, targets = inputs.to(device), targets.to(device) outputs = model(inputs) loss = criterion(outputs, targets) test_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() test_acc = correct / total print( '[INFO] Epoch-{}-Test Accurancy: {:.3f}'.format(epoch + 1, test_acc), '\n') total_test_acc.append(test_acc) # 保存权重文件 acc = 100. * correct / total if acc > best_acc: print('Saving..') state = { 'net': model.state_dict(), 'acc': acc, 'epoch': epoch, } if not os.path.isdir('checkpoint'): os.mkdir('checkpoint') torch.save(state, ckpt) best_acc = accif __name__ == '__main__': # 设置超参 parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--batch_size', type=int, default=128) parser.add_argument('--data', type=str, default='cifar10') parser.add_argument('--T_max', type=int, default=100) parser.add_argument('--lr', default=0.1, type=float, help='learning rate') parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') parser.add_argument('--checkpoint', type=str, default='checkpoint/ResNet18-CIFAR10.pth') opt = parser.parse_args() # 设置相关参数 device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu' best_acc = 0 # best test accuracy start_epoch = 0 # start from epoch 0 or last checkpoint epoch classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # 设置数据增强 print('==> Preparing data..') transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) # 加载CIFAR10数据集 trainset = torchvision.datasets.CIFAR10( root=opt.data, train=True, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader( trainset, batch_size=opt.batch_size, shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10( root=opt.data, train=False, download=True, transform=transform_test) testloader = torch.utils.data.DataLoader( testset, batch_size=100, shuffle=False, num_workers=2) # print(trainloader.dataset.shape) # 加载模型 print('==> Building model..') model = ResNet18().to(device) # DP训练 if device == 'cuda': model = torch.nn.DataParallel(model) cudnn.benchmark = True # 加载之前训练的参数 if opt.resume: # Load checkpoint. print('==> Resuming from checkpoint..') assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' checkpoint = torch.load(opt.checkpoint) model.load_state_dict(checkpoint['net']) best_acc = checkpoint['acc'] start_epoch = checkpoint['epoch'] # 设置损失函数与优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9, weight_decay=5e-4) # 余弦退火有序调整学习率 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.T_max) # ReduceLROnPlateau(自适应调整学习率) # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10) # 记录training和testing的acc total_test_acc = [] total_train_acc = [] # 记录训练时间 tic = time_sync() # 开始训练 for epoch in range(opt.epochs): train(epoch) test(epoch, opt.checkpoint) # 动态调整学习率 scheduler.step() # ReduceLROnPlateau(自适应调整学习率) # scheduler.step(loss_val) # 数据可视化 plt.figure() plt.plot(range(opt.epochs), total_train_acc, label='Train Accurancy') plt.plot(range(opt.epochs), total_test_acc, label='Test Accurancy') plt.xlabel('Epoch') plt.ylabel('Accurancy') plt.title('ResNet18-CIFAR10-Accurancy') plt.legend() plt.savefig('output/ResNet18-CIFAR10-Accurancy.jpg') # 自动保存plot出来的图片 plt.show() # 输出best_acc print(f'Best Acc: {best_acc * 100}%') toc = time_sync() # 计算本次运行时间 t = (toc - tic) / 3600 print(f'Training Done. ({t:.3f}s)')部分测试结果BackboneBest AccMobileNetv293.37%VGG1693.80%DenseNet12194.55%GoogLeNet95.02%ResNeXt29_32×4d95.18%ResNet5095.20%SENet1895.22%ResNet1895.23%完整工程文件

Pytorch实现CIFAR10图像分类任务测试集准确率达95%

Reference

CIFAR-10 数据集

深度学习入门基础教程(二) CNN做CIFAR10数据集图像分类 pytorch版代码

Pytorch CIFAR10 图像分类篇 汇总

pytorch-cifar:使用PyTorch在CIFAR10上为95.47%

本文链接地址:https://www.jiuchutong.com/zhishi/300727.html 转载请保留说明!

上一篇:vue中组件间通信的6种方式(vue之间的组件通信)

下一篇:自注意力(Self-Attention)与Multi-Head Attention机制详解(自注意力机制是什么)

  • iqoo8怎么设置应用消息不提醒(iqoo怎么设置应用锁)

    iqoo8怎么设置应用消息不提醒(iqoo怎么设置应用锁)

  • windows10激活密钥在哪(windows10激活密钥怎么获取)

    windows10激活密钥在哪(windows10激活密钥怎么获取)

  • 口罩面容id怎么设置(口罩face id)

    口罩面容id怎么设置(口罩face id)

  • apple music我喜欢的在哪里(apple music我喜欢列表)

    apple music我喜欢的在哪里(apple music我喜欢列表)

  • 小米妙享功能怎么使用的呢(小米妙享功能怎么关闭)

    小米妙享功能怎么使用的呢(小米妙享功能怎么关闭)

  • qq转发能屏蔽部分人吗(qq怎么在转发的时候屏蔽好友)

    qq转发能屏蔽部分人吗(qq怎么在转发的时候屏蔽好友)

  • 全民k歌网络错误是怎么回事(全民k歌网络错了怎么办)

    全民k歌网络错误是怎么回事(全民k歌网络错了怎么办)

  • 支付宝好友能看到啥(支付宝好友能看到我的哪些信息)

    支付宝好友能看到啥(支付宝好友能看到我的哪些信息)

  • 影响快手生态环境什么意思(快手生态环境包含哪些)

    影响快手生态环境什么意思(快手生态环境包含哪些)

  • 投屏播放总是自动退出(投屏播放总是自动播放)

    投屏播放总是自动退出(投屏播放总是自动播放)

  • 抖音本场点赞是什么意思(抖音本场点赞是音浪吗)

    抖音本场点赞是什么意思(抖音本场点赞是音浪吗)

  • 搜索微信号对方知道吗(搜索微信号对方头像一直没有变什么原因)

    搜索微信号对方知道吗(搜索微信号对方头像一直没有变什么原因)

  • 手机接不了电话但可以打出去(手机接不了电话也打不出去,能用流量)

    手机接不了电话但可以打出去(手机接不了电话也打不出去,能用流量)

  • qq群里发消息别人看不见(qq群发消息别人收不到)

    qq群里发消息别人看不见(qq群发消息别人收不到)

  • qq匿名投票管理员看到吗(qq匿名投票在哪儿弄)

    qq匿名投票管理员看到吗(qq匿名投票在哪儿弄)

  • 数据管理的三个阶段(数据管理的三个阶段的发展顺序正确的是)

    数据管理的三个阶段(数据管理的三个阶段的发展顺序正确的是)

  • 微信解封一年三次是怎么算的(微信解封一年三次怎么算时间)

    微信解封一年三次是怎么算的(微信解封一年三次怎么算时间)

  • 微信20分钟视频怎么发(微信20分钟视频发不过去)

    微信20分钟视频怎么发(微信20分钟视频发不过去)

  • 荣耀v30支持无线充电吗(荣耀80支持5g网络吗)

    荣耀v30支持无线充电吗(荣耀80支持5g网络吗)

  • vivos1是闪充吗(vivos1是多少w快充)

    vivos1是闪充吗(vivos1是多少w快充)

  • 华为原相机怎么调方形(华为原相机怎么调好看)

    华为原相机怎么调方形(华为原相机怎么调好看)

  • 小米无线车充支持哪些手机(小米无线车充支持苹果xr吗)

    小米无线车充支持哪些手机(小米无线车充支持苹果xr吗)

  • 怎么在手机屏幕上显示文字(怎么在手机屏幕上显示时间和天气)

    怎么在手机屏幕上显示文字(怎么在手机屏幕上显示时间和天气)

  • 怎么把系统装到固态硬盘(怎么把系统装到另外一个硬盘)

    怎么把系统装到固态硬盘(怎么把系统装到另外一个硬盘)

  • 怎样设置黑名单号码(怎样设置黑名单来电是关机)

    怎样设置黑名单号码(怎样设置黑名单来电是关机)

  • 新闻app大致有哪些功能(新闻app有哪些)

    新闻app大致有哪些功能(新闻app有哪些)

  • 微软正式宣布 Windows11:全新居中“开始”菜单,动态磁贴没了(微软正式宣布收购动视暴雪)

    微软正式宣布 Windows11:全新居中“开始”菜单,动态磁贴没了(微软正式宣布收购动视暴雪)

  • 账面价值与计税基础一般会产生差异的是
  • 项目奖金个人所得税怎么算
  • 进项税额转出是在借方还是贷方
  • 合同能源管理项目账务处理
  • 企业固定资产入账金额标准
  • 基建账的年终结转
  • 车辆后期保养费用
  • 店铺不盈利还开吗
  • 研发费用税点是什么意思
  • 当期应税销售收入是含税还是不含税
  • 沙特将开征增值税和特殊商品消费税
  • 个人购买办公楼出租要交税吗
  • 公司不盈利用交税吗
  • 不合规发票有哪些风险
  • 支付版权费用怎么入账
  • 存货减值账务处理 华图
  • 临时文件夹移动到c盘根目录下windows7
  • 财务红冲是什么意思
  • 软件开发公司账务怎么做
  • 新买的电脑如何激活windows
  • windows10如何更改时间
  • 未到期的应收票据向银行贴现什么时候计入短期贷款
  • 公会经费缴费单位应于每月
  • php7 数组
  • mac配置node环境
  • 发包工程补付工程款分录
  • 法人营业执照和非法人营业执照
  • 白鹤芋好养活吗
  • 个人销售非住宅土地增值税
  • 图卷积神经网络原理
  • javascript获取input的值并计算
  • 蛇形矩阵找数的位置
  • 期间费用计入产品成本的费用吗
  • node.js in action
  • 机票报销属于什么费
  • php session实例
  • 发放工资时扣除的保险怎么做
  • apache php mysql开发环境安装教程
  • vue数据加载完成显示页面过渡动画
  • 打开的ps关不掉
  • mysql零基础入门教程完整
  • 增值税发票税率1%
  • 新会计准则套期利息计算
  • ibm.data.db2
  • 金蝶kis云专业版使用教程
  • 承兑汇票找公司贴现违法吗
  • 当月的进项税可以不认证吗?
  • 应交税费借方是增加还是减少
  • 航信服务费减免怎么填
  • 应付职工薪酬代扣社保
  • 电子承兑汇票是到期日前10天提示承兑吗
  • 高速过路费怎么补交
  • 基本生产成本科目应该按成本计算对象
  • 税额抵减的账务处理
  • 一次性付款的优势
  • mysql在cmd命令操作
  • 文档介绍
  • xp系统设置壁纸
  • windowsqq截屏
  • win7 64位系统使用dos命令快速提高u盘传输速度的技巧
  • win8.1开不了机怎么办
  • cocos点击事件
  • nodejs阿里云
  • cocos2dx 2.2.2
  • Bullet(Cocos2dx)之创建地形
  • unit uniform
  • jquery输入框改变事件
  • dos命令到一个文件夹
  • unity调用c++封装的dll
  • android xml文件有哪几种布局方式
  • Unity3D的iTween
  • python爬虫类
  • python仿站软件官网
  • 国家税务总局在哪
  • 电子发票软件怎么打开
  • 关税由谁来承担
  • 2020年军人自主择业条件
  • 试验费属于什么税收编码
  • 小微企业房产税优惠减免政策
  • 购买税控盘怎么抵扣
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设