位置: IT常识 - 正文

Pytorch:手把手教你搭建简单的卷积神经网络(CNN),实现MNIST数据集分类任务(pytorch教程)

编辑:rootadmin
Pytorch:手把手教你搭建简单的卷积神经网络(CNN),实现MNIST数据集分类任务 关于一些代码里的解释,可以看我上一篇发布的文章,里面有很详细的介绍!!!可以依次把下面的代码段合在一起运行,也可以通过jupyter notebook分次运行第一步:基本库的导入import numpy as npimport torchimport torch.nn as nnimport torchvisionimport torchvision.transforms as transformsfrom torch.utils.data import DataLoaderimport matplotlib.pyplot as pltimport timenp.random.seed(1234)第二步:引用MNIST数据集,这里采用的是torchvision自带的MNIST数据集#这里用的是torchvision已经封装好的MINST数据集trainset=torchvision.datasets.MNIST( root='MNIST', #root是下载MNIST数据集保存的路径,可以自行修改 train=True, transform=torchvision.transforms.ToTensor(), download=True)testset=torchvision.datasets.MNIST( root='MNIST', train=False, transform=torchvision.transforms.ToTensor(), download=True)trainloader = DataLoader(dataset=trainset, batch_size=100, shuffle=True) #DataLoader是一个很好地能够帮助整理数据集的类,可以用来分批次,打乱以及多线程等操作testloader = DataLoader(dataset=testset, batch_size=100, shuffle=True)

推荐整理分享Pytorch:手把手教你搭建简单的卷积神经网络(CNN),实现MNIST数据集分类任务(pytorch教程),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch官方教程,pytorch入门教程(非常详细),pytorch 60分钟教程,pytorch 入门教程,pytorch 快速入门,pytorch 60分钟教程,pytorch怎么入门,pytorch 60分钟教程,内容如对您有帮助,希望把文章链接给更多的朋友!

下载之后利用DataLoader实例化为适合遍历的训练集和测试集,我们把其中的某一批数据进行可视化,下面是可视化的代码,其实就是利用subplot画了子图。

#可视化某一批数据train_img,train_label=next(iter(trainloader)) #iter迭代器,可以用来便利trainloader里面每一个数据,这里只迭代一次来进行可视化fig, axes = plt.subplots(10, 10, figsize=(10, 10))axes_list = []#输入到网络的图像for i in range(axes.shape[0]): for j in range(axes.shape[1]): axes[i, j].imshow(train_img[i*10+j,0,:,:],cmap="gray") #这里画出来的就是我们想输入到网络里训练的图像,与之对应的标签用来进行最后分类结果损失函数的计算 axes[i, j].axis("off")#对应的标签print(train_label)

  第三步:用pytorch搭建简单的卷积神经网络(CNN)

 这里把卷积模块单独拿出来作为一个类,看上去会舒服一点。

#卷积模块,由卷积核和激活函数组成class conv_block(nn.Module): def __init__(self,ks,ch_in,ch_out): super(conv_block,self).__init__() self.conv = nn.Sequential( nn.Conv2d(ch_in, ch_out, kernel_size=ks,stride=1,padding=1,bias=True), #二维卷积核,用于提取局部的图像信息 nn.ReLU(inplace=True), #这里用ReLU作为激活函数 nn.Conv2d(ch_out, ch_out, kernel_size=ks,stride=1,padding=1,bias=True), nn.ReLU(inplace=True), ) def forward(self,x): return self.conv(x)

下面是CNN主体部分,由上面的卷积模块和全连接分类器组合而成。这里只用了简单的几个卷积块进行堆叠,没有采用池化以及dropout的操作。主要目的是给大家简单搭建一下以便学习。

#常规CNN模块(由几个卷积模块堆叠而成)class CNN(nn.Module): def __init__(self,kernel_size,in_ch,out_ch): super(CNN, self).__init__() feature_list = [16,32,64,128,256] #代表每一层网络的特征数,扩大特征空间有助于挖掘更多的局部信息 self.conv1 = conv_block(kernel_size,in_ch,feature_list[0]) self.conv2 = conv_block(kernel_size,feature_list[0],feature_list[1]) self.conv3 = conv_block(kernel_size,feature_list[1],feature_list[2]) self.conv4 = conv_block(kernel_size,feature_list[2],feature_list[3]) self.conv5 = conv_block(kernel_size,feature_list[3],feature_list[4]) self.fc = nn.Sequential( #全连接层主要用来进行分类,整合采集的局部信息以及全局信息 nn.Linear(feature_list[4] * 28 * 28, 1024), #此处28为MINST一张图片的维度 nn.ReLU(), nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 10) ) def forward(self,x): device = x.device x1 = self.conv1(x ) x2 = self.conv2(x1) x3 = self.conv3(x2) x4 = self.conv4(x3) x5 = self.conv5(x4) x5 = x5.view(x5.size()[0], -1) #全连接层相当于做了矩阵乘法,所以这里需要将维度降维来实现矩阵的运算 out = self.fc(x5) return out第四步:训练以及模型保存

先是一些网络参数的定义,包括优化器,迭代轮数,学习率,运行硬件等等的确定。

#网络参数定义device = torch.device("cuda:4") #此处根据电脑配置进行选择,如果没有cuda就用cpu#device = torch.device("cpu")net = CNN(3,1,1).to(device = device,dtype = torch.float32)epochs = 50 #训练轮次optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-8) #使用Adam优化器criterion = nn.CrossEntropyLoss() #分类任务常用的交叉熵损失函数train_loss = []Pytorch:手把手教你搭建简单的卷积神经网络(CNN),实现MNIST数据集分类任务(pytorch教程)

然后是每一轮训练的主体:

# Begin trainingMinTrainLoss = 999for epoch in range(1,epochs+1): total_train_loss = [] net.train() start = time.time() for input_img,label in trainloader: input_img = input_img.to(device = device,dtype=torch.float32) #我们同样地,需要将我们取出来的训练集数据进行torch能够运算的格式转换 label = label.to(device = device,dtype=torch.float32) #输入和输出的格式都保持一致才能进行运算 optimizer.zero_grad() #每一次算loss前需要将之前的梯度清零,这样才不会影响后面的更新 pred_img = net(input_img) loss = criterion(pred_img,label.long()) loss.backward() optimizer.step() total_train_loss.append(loss.item()) train_loss.append(np.mean(total_train_loss)) #将一个minibatch里面的损失取平均作为这一轮的loss end = time.time() #打印当前的loss print("epochs[%3d/%3d] current loss: %.5f, time: %.3f"%(epoch,epochs,train_loss[-1],(end-start))) #打印每一轮训练的结果 if train_loss[-1]<MinTrainLoss: torch.save(net.state_dict(), "./model_min_train.pth") #保存loss最小的模型 MinTrainLoss = train_loss[-1]

以下是迭代过程:

 第五步:导入网络模型,输入某一批测试数据,查看结果

我们先来看某一批测试数据

#测试机某一批数据test_img,test_label=next(iter(testloader))fig, axes = plt.subplots(10, 10, figsize=(10, 10))axes_list = []#输入到网络的图像for i in range(axes.shape[0]): for j in range(axes.shape[1]): axes[i, j].imshow(test_img[i*10+j,0,:,:],cmap="gray") axes[i, j].axis("off")

然后将其输入到训练好的模型进行预测

#预测我拿出来的那一批数据进行展示cnn = CNN(3,1,1).to(device = device,dtype = torch.float32)cnn.load_state_dict(torch.load("./model_min_train.pth", map_location=device)) #导入我们之前已经训练好的模型cnn.eval() #评估模式test_img = test_img.to(device = device,dtype = torch.float32)test_label = test_label.to(device = device,dtype = torch.float32)pred_test = cnn(test_img) #记住,输出的结果是一个长度为10的tensortest_pred = np.argmax(pred_test.cpu().data.numpy(), axis=1) #所以我们需要对其进行最大值对应索引的处理,从而得到我们想要的预测结果#预测结果以及标签print("预测结果")print(test_pred)print("标签")print(test_label.cpu().data.numpy())

 

从预测的结果我们可以看到,整体上这么一个简单的CNN搭配全连接分类器对MNIST这一批数据分类的效果还不错。当然,我这里只用了交叉熵损失函数,并且没有计算准确率,仅供大家对于CNN学习和参考。

本文链接地址:https://www.jiuchutong.com/zhishi/300479.html 转载请保留说明!

上一篇:【GPT-3】第2章 使用 OpenAI API(gpt3 transformer)

下一篇:Java基础:笔试题(java基础笔试题在线考)

  • 华为手机怎么关闭喝水提醒(华为手机怎么关闭广告)

    华为手机怎么关闭喝水提醒(华为手机怎么关闭广告)

  • 惠普win10怎么进入安全模式(惠普win10怎么进入u盘系统)

    惠普win10怎么进入安全模式(惠普win10怎么进入u盘系统)

  • 直播间抢不到货是什么原因(直播间抢不到货有什么技巧)

    直播间抢不到货是什么原因(直播间抢不到货有什么技巧)

  • 为什么收不到特效短信(为什么收不到特定人短信)

    为什么收不到特效短信(为什么收不到特定人短信)

  • 联想2205硒鼓清零(联想2250硒鼓怎么清零)

    联想2205硒鼓清零(联想2250硒鼓怎么清零)

  • 微信信息发送成功声音怎么设置(微信信息发送成功声音设置)

    微信信息发送成功声音怎么设置(微信信息发送成功声音设置)

  • 华为荣耀9x怎么截屏(华为荣耀9x怎么取卡)

    华为荣耀9x怎么截屏(华为荣耀9x怎么取卡)

  • 去哪儿vip抢票一定能抢到吗(去哪儿vip抢票一天几次)

    去哪儿vip抢票一定能抢到吗(去哪儿vip抢票一天几次)

  • 用快捷键切换中英文输入方法时按什么键(用快捷键切换中英文输入为)

    用快捷键切换中英文输入方法时按什么键(用快捷键切换中英文输入为)

  • oppo怎么恢复出厂设置(oppo怎么恢复出厂设置方法按键)

    oppo怎么恢复出厂设置(oppo怎么恢复出厂设置方法按键)

  • 苹果蓝牙耳机怎么用不了(苹果蓝牙耳机怎么配对)

    苹果蓝牙耳机怎么用不了(苹果蓝牙耳机怎么配对)

  • 5.0ghz频段是什么意思(5.0 ghz频段)

    5.0ghz频段是什么意思(5.0 ghz频段)

  • 主板带m和不带m的区别(主板带m和不带m哪种更贵)

    主板带m和不带m的区别(主板带m和不带m哪种更贵)

  • 华为alp一al00是什么型号(华为alp-al00是华为什么型号)

    华为alp一al00是什么型号(华为alp-al00是华为什么型号)

  • 怎么看测出来的网速快慢(怎么看测出来的男孩女孩)

    怎么看测出来的网速快慢(怎么看测出来的男孩女孩)

  • 手机如何网上订酒店(手机如何网上订火车票)

    手机如何网上订酒店(手机如何网上订火车票)

  • win10运行内存怎么清理(win10运行内存怎么扩大)

    win10运行内存怎么清理(win10运行内存怎么扩大)

  • word修改不了内容(word无法修改内容)

    word修改不了内容(word无法修改内容)

  • 抖音注销后粉丝还有吗(抖音注销后粉丝灯牌还在吗)

    抖音注销后粉丝还有吗(抖音注销后粉丝灯牌还在吗)

  • 苹果电脑可以拓展内存吗(苹果电脑扩展屏幕切换快捷键)

    苹果电脑可以拓展内存吗(苹果电脑扩展屏幕切换快捷键)

  • 为啥苹果手机发语音没声音(为啥苹果手机发不出短信)

    为啥苹果手机发语音没声音(为啥苹果手机发不出短信)

  • 电脑屏幕上的图标不见了(电脑屏幕上的图标怎么调大小)

    电脑屏幕上的图标不见了(电脑屏幕上的图标怎么调大小)

  • 小米手环2支持ios吗 小米手环2支持苹果iphone手机(小米手环2支持nfc功能?)

    小米手环2支持ios吗 小米手环2支持苹果iphone手机(小米手环2支持nfc功能?)

  • win8改win7安装前的一些bios设置(win8换成win7重装系统)

    win8改win7安装前的一些bios设置(win8换成win7重装系统)

  • Vue项目部署上线全过程(保姆级教程)(vue项目部署上线 需要做哪些准备)

    Vue项目部署上线全过程(保姆级教程)(vue项目部署上线 需要做哪些准备)

  • python json保存数据的方法(pythonjson文件存储)

    python json保存数据的方法(pythonjson文件存储)

  • python return和yield有什么不同

    python return和yield有什么不同

  • 公司购买二手房可以开增值税专用发票吗
  • 税控盘的功能特点是
  • 会议服务费免税吗
  • 新公司开基本户银行选择
  • 房产税从租和从价
  • 债权人豁免债务的账务
  • 个人收到利息要交增值税吗为什么
  • 广告服务收入要计入什么科目
  • 防暑降温用品计算方法
  • 小规模纳税人增值税优惠政策
  • 虚开增值税普通罪量刑标准
  • 银行的划分标准
  • 增值税普票没有税率怎么回事
  • 应付账款从质保开始算吗
  • 公司委托其他公司办理事情
  • 车辆保险费的车船税计入什么会计科目
  • 民办非企业是否可以出资设立公司
  • 商业承兑汇票贴现为什么是短期借款
  • 小规模季超过30,增值税怎么收
  • 建筑中小企业
  • 企业销售食品过期处罚
  • 事业单位年末预算会计货币资金在贷方有余额对吗
  • 结转成本是否要等货物卖出后
  • 公司送礼分录
  • 请等待当前程序完成或更改怎么弄
  • Linux系统中矢量图ai格式怎么打开?
  • php+redis
  • 误删的文件怎么撤回
  • 上市公司股票如何套现
  • 出租车发票没有发票专用章是否能报销
  • 打印机疑难解答显示打印机问题
  • 发财树怎么养护与浇水
  • 未注销的坏账可以转出吗
  • 考核工资可以不发吗
  • session for
  • 其他流动资产是速动资产吗
  • arc架构
  • 气象数据32766
  • ubuntu busier
  • php array_merge_recursive 数组合并
  • 传统结算工具的不足有
  • python天气数据的爬取与分析
  • centos7.9 防火墙
  • 个体户个税计算公式
  • 现金科目的指定科目是什么
  • 长期待摊费用的摊销期限应该是
  • 金蝶迷你版年结账套
  • 公司两个股东变更为一个股东,需要交税么
  • 税务局核定税种需要多久
  • 改变记帐方式的原因
  • 企业会计凭证怎么写
  • 如何计提本年度工资总额
  • 会员卡收费
  • 支付宝收入什么意思
  • 划拨土地使用权管理暂行办法
  • 企业注销未分配是从注册开始吗
  • sql server中replace()函数用法解析
  • mysql增删改查面试题
  • mac下安装anaconda
  • 简单介绍linux系统有哪些主要特点?
  • win7系统怎么设置开机启动项
  • 解决Windows Server远程断开后自动
  • 如何移植操作系统
  • ezulumain.exe是病毒进程吗 ezulumain进程安全吗
  • Ubuntu GNOME 14.10的桌面升级到GNOME 3.16教程
  • deepin-win
  • 格式化不干净
  • centos7.6安装kvm
  • windows8怎么装windows10
  • linux emac
  • win10系统下怎么安装caxa2016电子图板 caxa2016电子图板安装详细图文教程
  • ExtJS 2.0实用简明教程 之Ext类库简介
  • excel实现多选
  • perl计算时间差
  • Javascript this 函数深入详解
  • linux shell有什么用
  • 用jquery实现动态添加
  • 山西电子税务局官网app
  • 孵化企业税收优惠
  • 新三步走和旧三步走的异同点
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设