位置: IT常识 - 正文

猿创征文|深度学习基于ResNet18网络完成图像分类(猿创部落是干什么的)

编辑:rootadmin
猿创征文|深度学习基于ResNet18网络完成图像分类 一.前言

推荐整理分享猿创征文|深度学习基于ResNet18网络完成图像分类(猿创部落是干什么的),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:猿创部落科技有限公司,猿文教育科技有限公司怎么样,猿文教育科技有限公司怎么样,猿创设计科技有限公司,猿创教育,猿文教育科技有限公司怎么样,猿创教育,猿创设计科技有限公司,内容如对您有帮助,希望把文章链接给更多的朋友!

本次任务是利用ResNet18网络实践更通用的图像分类任务。

ResNet系列网络,图像分类领域的知名算法,经久不衰,历久弥新,直到今天依旧具有广泛的研究意义和应用场景。被业界各种改进,经常用于图像识别任务。

今天主要介绍一下ResNet-18网络结构的案例,其他深层次网络,可以依次类推。

ResNet-18,数字代表的是网络的深度,也就是说ResNet18 网络就是18层的吗?实则不然,其实这里的18指定的是带有权重的 18层,包括卷积层和全连接层,不包括池化层和BN层。

图像分类(Image Classification)是计算机视觉中的一个基础任务,将图像的语义将不同图像划分到不同类别。很多任务也可以转换为图像分类任务。比如人脸检测就是判断一个区域内是否有人脸,可以看作一个二分类的图像分类任务。

数据集:使用的计算机视觉领域的经典CIFAR-10数据集网络层:网络为ResNet18模型优化器:优化器为Adam优化器损失函数:损失函数为交叉熵损失评价指标:评价指标为准确率

 ResNet 网络简介:

 

二.数据预处理2.1 数据集介绍

CIFAR-10数据集包含了10种不同的类别、共60,000张图像,其中每个类别的图像都是6000张,图像大小均为32×3232×32像素。

2.2 数据读取

在本实验中,将原始训练集拆分成了train_set、dev_set两个部分,分别包括40 000条和10 000条样本。将data_batch_1到data_batch_4作为训练集,data_batch_5作为验证集,test_batch作为测试集。 最终的数据集构成为:

训练集:40 000条样本。验证集:10 000条样本。测试集:10 000条样本。

读取一个batch数据的代码如下所示:

import osimport pickleimport numpy as npdef load_cifar10_batch(folder_path, batch_id=1, mode='train'): if mode == 'test': file_path = os.path.join(folder_path, 'test_batch') else: file_path = os.path.join(folder_path, 'data_batch_'+str(batch_id)) #加载数据集文件 with open(file_path, 'rb') as batch_file: batch = pickle.load(batch_file, encoding = 'latin1') imgs = batch['data'].reshape((len(batch['data']),3,32,32)) / 255. labels = batch['labels'] return np.array(imgs, dtype='float32'), np.array(labels)imgs_batch, labels_batch = load_cifar10_batch(folder_path='datasets/cifar-10-batches-py', batch_id=1, mode='train')猿创征文|深度学习基于ResNet18网络完成图像分类(猿创部落是干什么的)

查看数据的维度:

#打印一下每个batch中X和y的维度print ("batch of imgs shape: ",imgs_batch.shape, "batch of labels shape: ", labels_batch.shape)

batch of imgs shape:  (10000, 3, 32, 32) batch of labels shape:  (10000,)

可视化观察其中的一张样本图像和对应的标签,代码如下所示:

%matplotlib inlineimport matplotlib.pyplot as pltimage, label = imgs_batch[1], labels_batch[1]print("The label in the picture is {}".format(label))plt.figure(figsize=(2, 2))plt.imshow(image.transpose(1,2,0))plt.savefig('cnn-car.pdf')

2.3 构造Dataset类

构造一个CIFAR10Dataset类,其将继承自paddle.io.DataSet类,可以逐个数据进行处理。代码实现如下:

import paddleimport paddle.io as iofrom paddle.vision.transforms import Normalizeclass CIFAR10Dataset(io.Dataset): def __init__(self, folder_path='/home/aistudio/cifar-10-batches-py', mode='train'): if mode == 'train': #加载batch1-batch4作为训练集 self.imgs, self.labels = load_cifar10_batch(folder_path=folder_path, batch_id=1, mode='train') for i in range(2, 5): imgs_batch, labels_batch = load_cifar10_batch(folder_path=folder_path, batch_id=i, mode='train') self.imgs, self.labels = np.concatenate([self.imgs, imgs_batch]), np.concatenate([self.labels, labels_batch]) elif mode == 'dev': #加载batch5作为验证集 self.imgs, self.labels = load_cifar10_batch(folder_path=folder_path, batch_id=5, mode='dev') elif mode == 'test': #加载测试集 self.imgs, self.labels = load_cifar10_batch(folder_path=folder_path, mode='test') self.transform = Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010], data_format='CHW') def __getitem__(self, idx): img, label = self.imgs[idx], self.labels[idx] img = self.transform(img) return img, label def __len__(self): return len(self.imgs)paddle.seed(100)train_dataset = CIFAR10Dataset(folder_path='datasets/cifar-10-batches-py', mode='train')dev_dataset = CIFAR10Dataset(folder_path='datasets/cifar-10-batches-py', mode='dev')test_dataset = CIFAR10Dataset(folder_path='datasets/cifar-10-batches-py', mode='test')三、模型构建

使用飞桨高层API中的Resnet18进行图像分类实验。

from paddle.vision.models import resnet18resnet18_model = resnet18()

飞桨高层 API是对飞桨API的进一步封装与升级,提供了更加简洁易用的API,进一步提升了飞桨的易学易用性。其中,飞桨高层API封装了以下模块:

Model类,支持仅用几行代码完成模型的训练;图像预处理模块,包含数十种数据处理函数,基本涵盖了常用的数据处理、数据增强方法;计算机视觉领域和自然语言处理领域的常用模型,包括但不限于mobilenet、resnet、yolov3、cyclegan、bert、transformer、seq2seq等等,同时发布了对应模型的预训练模型,可以直接使用这些模型或者在此基础上完成二次开发。四、模型训练

复用RunnerV3类,实例化RunnerV3类,并传入训练配置。 使用训练集和验证集进行模型训练,共训练30个epoch。 在实验中,保存准确率最高的模型作为最佳模型。代码实现如下:

import paddle.nn.functional as Fimport paddle.optimizer as optfrom nndl import RunnerV3, Accuracy#指定运行设备use_gpu = True if paddle.get_device().startswith("gpu") else Falseif use_gpu: paddle.set_device('gpu:0')#学习率大小lr = 0.001 #批次大小batch_size = 64 #加载数据train_loader = io.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)dev_loader = io.DataLoader(dev_dataset, batch_size=batch_size)test_loader = io.DataLoader(test_dataset, batch_size=batch_size) #定义网络model = resnet18_model#定义优化器,这里使用Adam优化器以及l2正则化策略,相关内容在7.3.3.2和7.6.2中会进行详细介绍optimizer = opt.Adam(learning_rate=lr, parameters=model.parameters(), weight_decay=0.005)#定义损失函数loss_fn = F.cross_entropy#定义评价指标metric = Accuracy(is_logist=True)#实例化RunnerV3runner = RunnerV3(model, optimizer, loss_fn, metric)#启动训练log_steps = 3000eval_steps = 3000runner.train(train_loader, dev_loader, num_epochs=30, log_steps=log_steps, eval_steps=eval_steps, save_path="best_model.pdparams")

可视化观察训练集与验证集的准确率及损失变化情况。

from nndl import plotplot(runner, fig_name='cnn-loss4.pdf')

在本实验中,使用了第7章中介绍的Adam优化器进行网络优化,如果使用SGD优化器,会造成过拟合的现象,在验证集上无法得到很好的收敛效果。可以尝试使用第7章中其他优化策略调整训练配置,达到更高的模型精度。

五、模型评价

使用测试数据对在训练过程中保存的最佳模型进行评价,观察模型在测试集上的准确率以及损失情况。代码实现如下:

# 加载最优模型runner.load_model('best_model.pdparams')# 模型评价score, loss = runner.evaluate(test_loader)print("[Test] accuracy/loss: {:.4f}/{:.4f}".format(score, loss))

[Test] accuracy/loss: 0.7234/0.8324

六、模型预测¶

同样地,也可以使用保存好的模型,对测试集中的数据进行模型预测,观察模型效果,具体代码实现如下:

#获取测试集中的一个batch的数据X, label = next(test_loader())logits = runner.predict(X)#多分类,使用softmax计算预测概率pred = F.softmax(logits)#获取概率最大的类别pred_class = paddle.argmax(pred[2]).numpy()label = label[2][0].numpy()#输出真实类别与预测类别print("The true category is {} and the predicted category is {}".format(label[0], pred_class[0]))#可视化图片plt.figure(figsize=(2, 2))imgs, labels = load_cifar10_batch(folder_path='/home/aistudio/datasets/cifar-10-batches-py', mode='test')plt.imshow(imgs[2].transpose(1,2,0))plt.savefig('cnn-test-vis.pdf')

The true category is 8 and the predicted category is 8

真实是8,预测是8。ship

本文链接地址:https://www.jiuchutong.com/zhishi/300313.html 转载请保留说明!

上一篇:备战数学建模45-粒子群算法优化BP神经网络(攻坚站10)(数学建模心态崩了)

下一篇:ChatGPT在编程中的应用(编程中char什么意思)

  • 空调抽湿温度多少度合适(空调抽湿)(空调抽湿温度多少合适)

    空调抽湿温度多少度合适(空调抽湿)(空调抽湿温度多少合适)

  • bilibili怎么投屏到电视(bilibili怎么投屏连续播放)

    bilibili怎么投屏到电视(bilibili怎么投屏连续播放)

  • 花小猪怎么预约打车(花小猪怎么预约打车时间视频)

    花小猪怎么预约打车(花小猪怎么预约打车时间视频)

  • 新东方云教室怎么关闭摄像头(新东方云教室怎么开低性能)

    新东方云教室怎么关闭摄像头(新东方云教室怎么开低性能)

  • pdf是什么格式的文件(文件pdf格式怎么弄)

    pdf是什么格式的文件(文件pdf格式怎么弄)

  • 微信红包可以退回吗(微信红包可以退换回去吗)

    微信红包可以退回吗(微信红包可以退换回去吗)

  • 固态硬盘读取写入速度多少正常(固态硬盘读取写入速度重要吗)

    固态硬盘读取写入速度多少正常(固态硬盘读取写入速度重要吗)

  • 淘宝我的提问在哪删除(淘宝我的提问在哪里为什么不显示)

    淘宝我的提问在哪删除(淘宝我的提问在哪里为什么不显示)

  • 垂直同步开了会卡吗(垂直同步开了会怎么样)

    垂直同步开了会卡吗(垂直同步开了会怎么样)

  • airpods右边连不上(airpods右耳连接不了)

    airpods右边连不上(airpods右耳连接不了)

  • rf1是键盘上哪个键(lf1 rf1是什么键)

    rf1是键盘上哪个键(lf1 rf1是什么键)

  • 华为取消屏保设置方法是什么(华为手机取消屏幕保护设置)

    华为取消屏保设置方法是什么(华为手机取消屏幕保护设置)

  • 华为nova5pro可以遥控空调吗(华为nova5pro可以开空调吗)

    华为nova5pro可以遥控空调吗(华为nova5pro可以开空调吗)

  • 苹果已购买的项目隐藏了怎么找出来(苹果已购买的项目怎么退订)

    苹果已购买的项目隐藏了怎么找出来(苹果已购买的项目怎么退订)

  • 微信小蓝圈怎么弄(微信篮圈怎么设置)

    微信小蓝圈怎么弄(微信篮圈怎么设置)

  • 电脑钉钉连麦对方听不到声音(电脑钉钉连麦对方听不到我的声音是什么原因?)

    电脑钉钉连麦对方听不到声音(电脑钉钉连麦对方听不到我的声音是什么原因?)

  • sata线分2.0和3.0吗(sata线分2.0和3.0外观)

    sata线分2.0和3.0吗(sata线分2.0和3.0外观)

  • 已移除蜂窝移动号码什么意思(已移除蜂窝移动怎么恢复)

    已移除蜂窝移动号码什么意思(已移除蜂窝移动怎么恢复)

  • sata2和sata3外观区别(sata2和sata3接口区别大吗)

    sata2和sata3外观区别(sata2和sata3接口区别大吗)

  • 被注销的qq长什么样(被注销的qq长什么样图片)

    被注销的qq长什么样(被注销的qq长什么样图片)

  • vivoz5反向充电怎么用(vivoy51s反向充电)

    vivoz5反向充电怎么用(vivoy51s反向充电)

  • 抖音清空消息方法介绍(抖音清空消息方式有哪些)

    抖音清空消息方法介绍(抖音清空消息方式有哪些)

  • 微信语音聊天可以录音吗(微信语音聊天可以作为法律证据吗?)

    微信语音聊天可以录音吗(微信语音聊天可以作为法律证据吗?)

  • 微信金山文档怎么用电脑打开(微信金山文档怎么转换成文件)

    微信金山文档怎么用电脑打开(微信金山文档怎么转换成文件)

  • 小米内存卡在哪里打开(小米内存卡在哪个位置)

    小米内存卡在哪里打开(小米内存卡在哪个位置)

  • oppo热点资讯怎么关(oppo热点资讯怎么关掉通知)

    oppo热点资讯怎么关(oppo热点资讯怎么关掉通知)

  • mac忘记开机密码(双系统mac忘记开机密码)

    mac忘记开机密码(双系统mac忘记开机密码)

  • 华为助手怎么叫出来(华为手机助手怎么喊出来)

    华为助手怎么叫出来(华为手机助手怎么喊出来)

  • 工商年报的纳税总额是什么
  • 企业的书报费应计入销售费用
  • 个体工商户单位性质怎么填
  • 坏账准备转回的条件
  • 企业合并吸收税务处理
  • 高新技术认定标准条件是什么
  • 分公司预缴企业所得税总公司可以抵扣吗
  • 地税补缴社保
  • 租的厂房水电费开不了发票怎么办
  • 残保金是谁支付给单位?
  • 合同租金总收入怎么填
  • 原材料的归集和整理
  • 递延收益没有应列入哪个科目
  • 小规模纳税人销售农产品税率是多少
  • 未认证的进项也就是库存
  • 个体广告用去税务报账吗?
  • 自然人股权转让涉税信息怎么填
  • 虚开进项税额转出会计分录
  • 什么是差额征税,什么情况下适用差额征税
  • 税控盘的购买流程
  • 红字发票开具只能针对一份发票 不可以只冲红其中一部分吗?
  • 税务会计核算范围
  • 报个税系统叫啥
  • 应交税费的期初余额是借还是贷
  • 多缴的社保退还给员工是否还要算个税
  • 建安企业结转成本如何计算
  • 属于留存收益的是
  • 服务业成本会计分析
  • 劳务费没发票怎样下账
  • 总账建账科目顺序
  • 商会收到的会费要交企业所得税
  • 失控发票的企业怎么处理
  • win11 zen2
  • 企业出租涉及到的税收
  • php使用内置函数的过程
  • 设计费包含概算费用吗
  • 营业外收支计算公式
  • PHP:imagecreatetruecolor()的用法_GD库图像处理函数
  • 动静结合会计等式的不会重复算利润吗
  • vue父子组件传值
  • 转回已核销的坏账分录
  • 应收账款周转率多少合适
  • 个体工商户城市维护建设税
  • 财报层次和认定层次
  • 发票校验码是什么在哪
  • 国库集中支付发送签收失败
  • 应收票据贴现的会计处理
  • php怎么连接sqlserver
  • 会计年度对账
  • 暂估和冲暂估分录
  • 财产转让按什么计征
  • 地方水利建设基金减免政策2023
  • 京东提现一般多久到账
  • 代销商品怎么开票
  • 预付账款的相关认定
  • 纳税人将外购的货物用于非应税项目
  • 出口退税申报的报关单无电子信息
  • 职工福利费和工会经费
  • 母子公司资金往来财税问题
  • 错误凭证如何处理
  • 职工福利费核算哪些内容
  • win7系统IE浏览器打开跳转到360浏览器,怎么阻止
  • win7怎么隐藏我的电脑
  • centos 安装
  • 一键ghost还原备份
  • xp桌面浏览器图标不见了
  • windows 8.1更新
  • windows8中文版是什么版本
  • 【Cocso2d-x Lua笔记五】quick中的display
  • 简述javascript执行原理
  • nodejs如何配置环境变量
  • shell脚本显示进度条
  • position属性含义
  • android 4.2
  • 服务协议属于哪类合同
  • 两江新区钓鱼地方
  • 汽车销售流程有哪些环节?每个环节的主要内容是什么?
  • 留抵税额退税政策2022年14号文件
  • 地税局属于国家公务员吗
  • 美国纽约购物
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设