位置: IT常识 - 正文

基于Pytorch的MNIST手写数字识别实现(含代码+讲解)(基于Pytorch的风格转换)

编辑:rootadmin
基于Pytorch的MNIST手写数字识别实现(含代码+讲解)

推荐整理分享基于Pytorch的MNIST手写数字识别实现(含代码+讲解)(基于Pytorch的风格转换),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:基于Pytorch的LSTM模型对股价的分析与预测,基于pytorch的图像分类算法,基于Pytorch的垃圾分类毕业论文,基于Pytorch的结构化剪枝,基于Pytorch的结构化剪枝,基于Pytorch的结构化剪枝,基于Pytorch的LSTM模型对股价的分析与预测,基于Pytorch的车道线检测,内容如对您有帮助,希望把文章链接给更多的朋友!

说明:本人也是一个萌新,也在学习中,有代码里也有不完善的地方。如果有错误/讲解不清的地方请多多指出

本文代码链接:

GitHub - Michael-OvO/mnist: mnist_trained_model with torch

明确任务目标:

使用pytorch作为框架使用mnist数据集训练一个手写数字的识别

换句话说:输入为

输出: 0

比较简单直观

1. 环境搭建 

需要安装Pytorch, 具体过程因系统而异,这里也就不多赘述了

具体教程可以参考这个视频 (这个系列的P1是环境配置)

PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili【已完结!!!已完结!!!2021年5月31日已完结】本系列教程,将带你用全新的思路,快速入门PyTorch。独创的学习思路,仅此一家。个人公众号:我是土堆各种资料,请自取。代码:https://github.com/xiaotudui/PyTorch-Tutorial蚂蚁蜜蜂/练手数据集:链接: https://pan.baidu.com/s/1jZoTmoFzaTLWh4lKBHVbEA 密码https://www.bilibili.com/video/BV1hE411t7RN?share_source=copy_web

2. 基本导入import torchimport torchvisionfrom torch.utils.data import DataLoaderimport torch.nn as nnimport torch.optim as optimfrom torch.utils.tensorboard import SummaryWriterimport timeimport matplotlib.pyplot as pltimport randomfrom numpy import argmax

不多解释,导入各种需要的包

3. 基本参数定义#Basic Params-----------------------------epoch = 1learning_rate = 0.01batch_size_train = 64batch_size_test = 1000gpu = torch.cuda.is_available()momentum = 0.5

epoch是整体进行几批训练

learning rate 为学习率

随后是每批训练数据大小和测试数据大小

gpu是一个布尔值,方便没有显卡的同学可以不用cuda加速,但是有显卡的同学可以优先使用cuda

momentum 是动量,避免找不到局部最优解的尴尬情况

这些都是比较基本的网络参数

4. 数据加载

使用Dataloader加载数据,如果是第一次运行将会从网上下载数据

如果下载一直不行的话也可以从官方直接下载并放入./data目录即可

基于Pytorch的MNIST手写数字识别实现(含代码+讲解)(基于Pytorch的风格转换)

​​​​​​MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

 (有4个包都需要下载)

#Load Data-------------------------------train_loader = DataLoader(torchvision.datasets.MNIST('./data/', train=True, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( (0.1307,), (0.3081,)) ])), batch_size=batch_size_train, shuffle=True)test_loader = DataLoader(torchvision.datasets.MNIST('./data/', train=False, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( (0.1307,), (0.3081,)) ])), batch_size=batch_size_test, shuffle=True)train_data_size = len(train_loader)test_data_size = len(test_loader)5. 网络定义

接下来是重中之重

网络的定义

这边的网络结构参考了这张图:

有了结构图,代码就很好写了, 直接对着图敲出来就好了

非常建议使用sequential直接写网络结构,会方便很多

#Define Model----------------------------class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.model = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Flatten(), nn.Linear(in_features=3136, out_features=128), nn.Linear(in_features=128, out_features=10), ) def forward(self, x): return self.model(x)if gpu: net = Net().cuda()else: net = Net()6.损失函数和优化器

交叉熵和SGD(随机梯度下降)

另外为了方便记录训练情况可以使用TensorBoard的Summary Writer

#Define Loss and Optimizer----------------if gpu: loss_fn = nn.CrossEntropyLoss().cuda()else: loss_fn = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)#Define Tensorboard-------------------writer = SummaryWriter(log_dir='logs/{}'.format(time.strftime('%Y%m%d-%H%M%S')))7. 开始训练#Train---------------------------------total_train_step = 0def train(epoch): global total_train_step total_train_step = 0 for data in train_loader: imgs,targets = data if gpu: imgs,targets = imgs.cuda(),targets.cuda() optimizer.zero_grad() outputs = net(imgs) loss = loss_fn(outputs,targets) loss.backward() optimizer.step() if total_train_step % 200 == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, total_train_step, train_data_size, 100. * total_train_step / train_data_size, loss.item())) writer.add_scalar('loss', loss.item(), total_train_step) total_train_step += 1#Test---------------------------------def test(): correct = 0 total = 0 with torch.no_grad(): for data in test_loader: imgs,targets = data if gpu: imgs,targets = imgs.cuda(),targets.cuda() outputs = net(imgs) _,predicted = torch.max(outputs.data,1) total += targets.size(0) correct += (predicted == targets).sum().item() print('Test Accuracy: {}/{} ({:.0f}%)'.format(correct, total, 100.*correct/total)) return correct/total#Run----------------------------------for i in range(1,epoch+1): print("-----------------Epoch: {}-----------------".format(i)) train(i) test() writer.add_scalar('test_accuracy', test(), total_train_step) #save model torch.save(net,'model/mnist_model.pth') print('Saved model')writer.close()

注意这里必须要先在同一文件夹下创建一个叫做model的文件夹!!!不然模型目录将找不到地方保存!!!会报错!!!

Train函数作为训练,Test函数作为测试

注意每次训练需要梯度清零

模型测试时要写with torch.no_grad()

运行的过程如果有GPU加速会很快,运行结果应该如下

 正确率也还算是可以,一个epoch就能跑到98,如果不满意或者想调epoch次数可以在basic params区域直接进行修改

8. 模型验证和结果展示

小细节很多

首先是抽取样本的时候需要考虑转cuda的问题

其次如果直接将样本扔到网络里dimension不对,需要reshape

需要对结果进行argmax处理,因为结果是一个向量(有10个features,分别代表每个数字的概率),argmax会找到最大概率并输出模型的预测结果

使用matplotlib画图

#Evaluate---------------------------------model = torch.load("./model/mnist_model.pth")model.eval()print(model)fig = plt.figure(figsize=(20,20))for i in range(20): #随机抽取20个样本 index = random.randint(0,test_data_size) data = test_loader.dataset[index] #注意Cuda问题 if gpu: img = data[0].cuda() else: img = data[0] #维度不对必须要reshape img = torch.reshape(img,(1,1,28,28)) with torch.no_grad(): output = model(img) #plot the image and the predicted number fig.add_subplot(4,5,i+1) #一定要取Argmax!!! plt.title(argmax(output.data.cpu().numpy())) plt.imshow(data[0].numpy().squeeze(),cmap='gray')plt.show()

运行结果如下:

效果还是很不错的

至此我们就完成了一整个模型训练,保存,导入,验证的基本流程。

完整代码import torchimport torchvisionfrom torch.utils.data import DataLoaderimport torch.nn as nnimport torch.optim as optimfrom torch.utils.tensorboard import SummaryWriterimport timeimport matplotlib.pyplot as pltimport randomfrom numpy import argmax#Basic Params-----------------------------epoch = 1learning_rate = 0.01batch_size_train = 64batch_size_test = 1000gpu = torch.cuda.is_available()momentum = 0.5#Load Data-------------------------------train_loader = DataLoader(torchvision.datasets.MNIST('./data/', train=True, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( (0.1307,), (0.3081,)) ])), batch_size=batch_size_train, shuffle=True)test_loader = DataLoader(torchvision.datasets.MNIST('./data/', train=False, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( (0.1307,), (0.3081,)) ])), batch_size=batch_size_test, shuffle=True)train_data_size = len(train_loader)test_data_size = len(test_loader)#Define Model----------------------------class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.model = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Flatten(), nn.Linear(in_features=3136, out_features=128), nn.Linear(in_features=128, out_features=10), ) def forward(self, x): return self.model(x)if gpu: net = Net().cuda()else: net = Net()#Define Loss and Optimizer----------------if gpu: loss_fn = nn.CrossEntropyLoss().cuda()else: loss_fn = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)#Define Tensorboard-------------------writer = SummaryWriter(log_dir='logs/{}'.format(time.strftime('%Y%m%d-%H%M%S')))#Train---------------------------------total_train_step = 0def train(epoch): global total_train_step total_train_step = 0 for data in train_loader: imgs,targets = data if gpu: imgs,targets = imgs.cuda(),targets.cuda() optimizer.zero_grad() outputs = net(imgs) loss = loss_fn(outputs,targets) loss.backward() optimizer.step() if total_train_step % 200 == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, total_train_step, train_data_size, 100. * total_train_step / train_data_size, loss.item())) writer.add_scalar('loss', loss.item(), total_train_step) total_train_step += 1#Test---------------------------------def test(): correct = 0 total = 0 with torch.no_grad(): for data in test_loader: imgs,targets = data if gpu: imgs,targets = imgs.cuda(),targets.cuda() outputs = net(imgs) _,predicted = torch.max(outputs.data,1) total += targets.size(0) correct += (predicted == targets).sum().item() print('Test Accuracy: {}/{} ({:.0f}%)'.format(correct, total, 100.*correct/total)) return correct/total#Run----------------------------------for i in range(1,epoch+1): print("-----------------Epoch: {}-----------------".format(i)) train(i) test() writer.add_scalar('test_accuracy', test(), total_train_step) #save model torch.save(net,'model/mnist_model.pth') print('Saved model')writer.close()#Evaluate---------------------------------model = torch.load("./model/mnist_model.pth")model.eval()print(model)fig = plt.figure(figsize=(20,20))for i in range(20): #random number index = random.randint(0,test_data_size) data = test_loader.dataset[index] if gpu: img = data[0].cuda() else: img = data[0] img = torch.reshape(img,(1,1,28,28)) with torch.no_grad(): output = model(img) #plot the image and the predicted number fig.add_subplot(4,5,i+1) plt.title(argmax(output.data.cpu().numpy())) plt.imshow(data[0].numpy().squeeze(),cmap='gray')plt.show()
本文链接地址:https://www.jiuchutong.com/zhishi/300439.html 转载请保留说明!

上一篇:【yolov6系列一】深度解析网络架构(yolov5官方)

下一篇:【Spring】IOC,你真的懂了吗?(spring ioc di aop)

  • 苹果辅助球怎么打开(苹果辅助球怎么移动)

    苹果辅助球怎么打开(苹果辅助球怎么移动)

  • 华为手机nova6灭屏显示时间(华为手机nova6怎么没有灭屏功能)

    华为手机nova6灭屏显示时间(华为手机nova6怎么没有灭屏功能)

  • 华为手机桌面布局已锁定怎么办(华为手机桌面布局锁定怎么开)

    华为手机桌面布局已锁定怎么办(华为手机桌面布局锁定怎么开)

  • 8寸大屏幕手机有哪些(热门8寸大屏手机大全)

    8寸大屏幕手机有哪些(热门8寸大屏手机大全)

  • 计算机与电子乐器之间通过什么进行数据交换(计算机电子乐谱)

    计算机与电子乐器之间通过什么进行数据交换(计算机电子乐谱)

  • 华为手机如何解除安全模式(华为手机如何解除呼叫限制)

    华为手机如何解除安全模式(华为手机如何解除呼叫限制)

  • 苹果11为什么面容解锁突然用不了了呢(苹果11为什么面容解锁有时不好使)

    苹果11为什么面容解锁突然用不了了呢(苹果11为什么面容解锁有时不好使)

  • 淘宝退款成功了为什么还有确认收货(淘宝退款成功了但是货收到了怎么办)

    淘宝退款成功了为什么还有确认收货(淘宝退款成功了但是货收到了怎么办)

  • 路由模式和桥接模式的区别(路由模式和桥接模式和中继模式)

    路由模式和桥接模式的区别(路由模式和桥接模式和中继模式)

  • 怎样知道小红书限流了(怎样知道小红书笔记审核通过)

    怎样知道小红书限流了(怎样知道小红书笔记审核通过)

  • 苹果手机换过屏幕有什么影响(苹果手机换过屏幕值得买吗)

    苹果手机换过屏幕有什么影响(苹果手机换过屏幕值得买吗)

  • vivoy67上市时间(vivoy67a上市时间)

    vivoy67上市时间(vivoy67a上市时间)

  • 绑定淘宝客pid是什么意思(淘宝客pid是几位数字在哪里找)

    绑定淘宝客pid是什么意思(淘宝客pid是几位数字在哪里找)

  • word文档里面怎么画线(word文档里面怎么打勾勾)

    word文档里面怎么画线(word文档里面怎么打勾勾)

  • 手机有个耳机标志,手机没声音(手机有个耳机标志,手机没声音怎么调)

    手机有个耳机标志,手机没声音(手机有个耳机标志,手机没声音怎么调)

  • iphone蓝牙搜不到beats(iPhone蓝牙搜不到小米手表)

    iphone蓝牙搜不到beats(iPhone蓝牙搜不到小米手表)

  • 饿了么怎么帮外地订餐(饿了么怎么帮外地人点餐)

    饿了么怎么帮外地订餐(饿了么怎么帮外地人点餐)

  • 爱奇艺怎么不能调清晰度了(爱奇艺怎么不能小窗口播放)

    爱奇艺怎么不能调清晰度了(爱奇艺怎么不能小窗口播放)

  • 支付宝步数和手机不一致(支付宝步数和手机步数不一样)

    支付宝步数和手机不一致(支付宝步数和手机步数不一样)

  • 荣耀20pro有红外吗(荣耀80pro有红外功能吗)

    荣耀20pro有红外吗(荣耀80pro有红外功能吗)

  • ubuntu卸载软件(Ubuntu卸载软件包)

    ubuntu卸载软件(Ubuntu卸载软件包)

  • 天猫精灵曲奇连不上网(天猫精灵曲奇连不上怎么办)

    天猫精灵曲奇连不上网(天猫精灵曲奇连不上怎么办)

  • 笔记本频繁自动关机(笔记本频繁自动关机怎么解决)

    笔记本频繁自动关机(笔记本频繁自动关机怎么解决)

  • xr怎么设置电池百分比(苹果xr如何设置电池)

    xr怎么设置电池百分比(苹果xr如何设置电池)

  • 如何通过U盘重装WIN7系统?(如何通过u盘重启)

    如何通过U盘重装WIN7系统?(如何通过u盘重启)

  • ftptop命令  显示服务器的连接状态(ftp命令行)

    ftptop命令 显示服务器的连接状态(ftp命令行)

  • 印花税办理流程
  • 营业成本和生产成本的公式
  • 围挡属于什么类型
  • 纳税人性质怎么改
  • 预缴的增值税怎么算
  • 民间非营利组织会计制度最新版
  • 已交增值税如何做账
  • 营改增前未完工的老项目可以开专票吗
  • 个体工商户需要每个月报税吗
  • 未取得发票如何进应付暂估科目
  • 土地增值税清算扣除项目
  • 公司对项目管理方式
  • 怎么填报清算所得税申报表?
  • 从外面买回来的菜怎么消毒
  • 出口预收货款发生的时间和报表上的时间不一样怎么办
  • 按计划成本发出原材料怎么算
  • 核定征收所得税税率
  • 公司罚款作为一种对过错方式的处罚
  • 退货后发票还能拿去抵税吗
  • 增值税发票不见了可以重开吗
  • 附加税是地方还是国家的
  • 机动车类专用发票
  • 个人租车给公司租金多少合适
  • 法人变更了还用变更发票领用本吗
  • 个体户怎么申请核定征收
  • 预付股权转让款如何处理
  • 高新企业所得税税率10%
  • 个体工商户起征点10万执行时间
  • 进项税和销项税的分录
  • steam打开速度
  • 关联企业需要计提坏账
  • win10桌面2怎么使用
  • 如何实现php图片打印
  • 订金账务处理
  • 劳动合同到期补偿金怎么算
  • 企业租赁房屋怎么开发票
  • 笔记本电池保养注意事项
  • 短期投资需要结转吗
  • 划入账户金额
  • 手把手教你用气焊视频
  • 汇兑损益计入什么科目
  • vue 富文本编辑框
  • python tle
  • 什么发票才能做账务处理
  • 库存盘点差异会计分录
  • 一般纳税人在哪里报税
  • 营业执照更换法人需要哪些手续
  • 免税增值税纳税申报表怎么填
  • 人力资源公司代办
  • 社保基数怎么申请下调
  • 员工借现金分录
  • 摄影行业开票
  • mysql 缓冲区
  • 固定资产入账怎么做凭证和入资产卡片?
  • 非税收入一般缴款书查询
  • 库存商品怎么做表格
  • 汽车加油费属于交通费用吗
  • 收到退回多付的材料退款
  • 将sql语句的执行状态传递给主语言的是
  • 数据库备份sqlserver
  • macbookpro如何扫描
  • 火狐firefox浏览器华为
  • mac u 盘启动
  • bootcamp不用u盘
  • centos ulimit
  • win7 64位纯净版图标变成了一样该怎么办?win7旗舰版图标变成一样的解决方法
  • win10怎么这只让任务栏图标居中显示?
  • node的fs模块
  • android开发环境配置
  • jquery如何获取input的值
  • jquery网站开发
  • jQuery ajax的功能实现方法详解
  • python回归结果输出
  • vue 父子组件通信
  • jquery.min.js源代码
  • javascript的
  • jquery校验form表单
  • js设置延时执行
  • 土地增值税法定扣除项目
  • 二手商铺买卖
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设