位置: IT常识 - 正文

基于Pytorch的MNIST手写数字识别实现(含代码+讲解)(基于Pytorch的风格转换)

编辑:rootadmin
基于Pytorch的MNIST手写数字识别实现(含代码+讲解)

推荐整理分享基于Pytorch的MNIST手写数字识别实现(含代码+讲解)(基于Pytorch的风格转换),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:基于Pytorch的LSTM模型对股价的分析与预测,基于pytorch的图像分类算法,基于Pytorch的垃圾分类毕业论文,基于Pytorch的结构化剪枝,基于Pytorch的结构化剪枝,基于Pytorch的结构化剪枝,基于Pytorch的LSTM模型对股价的分析与预测,基于Pytorch的车道线检测,内容如对您有帮助,希望把文章链接给更多的朋友!

说明:本人也是一个萌新,也在学习中,有代码里也有不完善的地方。如果有错误/讲解不清的地方请多多指出

本文代码链接:

GitHub - Michael-OvO/mnist: mnist_trained_model with torch

明确任务目标:

使用pytorch作为框架使用mnist数据集训练一个手写数字的识别

换句话说:输入为

输出: 0

比较简单直观

1. 环境搭建 

需要安装Pytorch, 具体过程因系统而异,这里也就不多赘述了

具体教程可以参考这个视频 (这个系列的P1是环境配置)

PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili【已完结!!!已完结!!!2021年5月31日已完结】本系列教程,将带你用全新的思路,快速入门PyTorch。独创的学习思路,仅此一家。个人公众号:我是土堆各种资料,请自取。代码:https://github.com/xiaotudui/PyTorch-Tutorial蚂蚁蜜蜂/练手数据集:链接: https://pan.baidu.com/s/1jZoTmoFzaTLWh4lKBHVbEA 密码https://www.bilibili.com/video/BV1hE411t7RN?share_source=copy_web

2. 基本导入import torchimport torchvisionfrom torch.utils.data import DataLoaderimport torch.nn as nnimport torch.optim as optimfrom torch.utils.tensorboard import SummaryWriterimport timeimport matplotlib.pyplot as pltimport randomfrom numpy import argmax

不多解释,导入各种需要的包

3. 基本参数定义#Basic Params-----------------------------epoch = 1learning_rate = 0.01batch_size_train = 64batch_size_test = 1000gpu = torch.cuda.is_available()momentum = 0.5

epoch是整体进行几批训练

learning rate 为学习率

随后是每批训练数据大小和测试数据大小

gpu是一个布尔值,方便没有显卡的同学可以不用cuda加速,但是有显卡的同学可以优先使用cuda

momentum 是动量,避免找不到局部最优解的尴尬情况

这些都是比较基本的网络参数

4. 数据加载

使用Dataloader加载数据,如果是第一次运行将会从网上下载数据

如果下载一直不行的话也可以从官方直接下载并放入./data目录即可

基于Pytorch的MNIST手写数字识别实现(含代码+讲解)(基于Pytorch的风格转换)

​​​​​​MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

 (有4个包都需要下载)

#Load Data-------------------------------train_loader = DataLoader(torchvision.datasets.MNIST('./data/', train=True, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( (0.1307,), (0.3081,)) ])), batch_size=batch_size_train, shuffle=True)test_loader = DataLoader(torchvision.datasets.MNIST('./data/', train=False, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( (0.1307,), (0.3081,)) ])), batch_size=batch_size_test, shuffle=True)train_data_size = len(train_loader)test_data_size = len(test_loader)5. 网络定义

接下来是重中之重

网络的定义

这边的网络结构参考了这张图:

有了结构图,代码就很好写了, 直接对着图敲出来就好了

非常建议使用sequential直接写网络结构,会方便很多

#Define Model----------------------------class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.model = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Flatten(), nn.Linear(in_features=3136, out_features=128), nn.Linear(in_features=128, out_features=10), ) def forward(self, x): return self.model(x)if gpu: net = Net().cuda()else: net = Net()6.损失函数和优化器

交叉熵和SGD(随机梯度下降)

另外为了方便记录训练情况可以使用TensorBoard的Summary Writer

#Define Loss and Optimizer----------------if gpu: loss_fn = nn.CrossEntropyLoss().cuda()else: loss_fn = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)#Define Tensorboard-------------------writer = SummaryWriter(log_dir='logs/{}'.format(time.strftime('%Y%m%d-%H%M%S')))7. 开始训练#Train---------------------------------total_train_step = 0def train(epoch): global total_train_step total_train_step = 0 for data in train_loader: imgs,targets = data if gpu: imgs,targets = imgs.cuda(),targets.cuda() optimizer.zero_grad() outputs = net(imgs) loss = loss_fn(outputs,targets) loss.backward() optimizer.step() if total_train_step % 200 == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, total_train_step, train_data_size, 100. * total_train_step / train_data_size, loss.item())) writer.add_scalar('loss', loss.item(), total_train_step) total_train_step += 1#Test---------------------------------def test(): correct = 0 total = 0 with torch.no_grad(): for data in test_loader: imgs,targets = data if gpu: imgs,targets = imgs.cuda(),targets.cuda() outputs = net(imgs) _,predicted = torch.max(outputs.data,1) total += targets.size(0) correct += (predicted == targets).sum().item() print('Test Accuracy: {}/{} ({:.0f}%)'.format(correct, total, 100.*correct/total)) return correct/total#Run----------------------------------for i in range(1,epoch+1): print("-----------------Epoch: {}-----------------".format(i)) train(i) test() writer.add_scalar('test_accuracy', test(), total_train_step) #save model torch.save(net,'model/mnist_model.pth') print('Saved model')writer.close()

注意这里必须要先在同一文件夹下创建一个叫做model的文件夹!!!不然模型目录将找不到地方保存!!!会报错!!!

Train函数作为训练,Test函数作为测试

注意每次训练需要梯度清零

模型测试时要写with torch.no_grad()

运行的过程如果有GPU加速会很快,运行结果应该如下

 正确率也还算是可以,一个epoch就能跑到98,如果不满意或者想调epoch次数可以在basic params区域直接进行修改

8. 模型验证和结果展示

小细节很多

首先是抽取样本的时候需要考虑转cuda的问题

其次如果直接将样本扔到网络里dimension不对,需要reshape

需要对结果进行argmax处理,因为结果是一个向量(有10个features,分别代表每个数字的概率),argmax会找到最大概率并输出模型的预测结果

使用matplotlib画图

#Evaluate---------------------------------model = torch.load("./model/mnist_model.pth")model.eval()print(model)fig = plt.figure(figsize=(20,20))for i in range(20): #随机抽取20个样本 index = random.randint(0,test_data_size) data = test_loader.dataset[index] #注意Cuda问题 if gpu: img = data[0].cuda() else: img = data[0] #维度不对必须要reshape img = torch.reshape(img,(1,1,28,28)) with torch.no_grad(): output = model(img) #plot the image and the predicted number fig.add_subplot(4,5,i+1) #一定要取Argmax!!! plt.title(argmax(output.data.cpu().numpy())) plt.imshow(data[0].numpy().squeeze(),cmap='gray')plt.show()

运行结果如下:

效果还是很不错的

至此我们就完成了一整个模型训练,保存,导入,验证的基本流程。

完整代码import torchimport torchvisionfrom torch.utils.data import DataLoaderimport torch.nn as nnimport torch.optim as optimfrom torch.utils.tensorboard import SummaryWriterimport timeimport matplotlib.pyplot as pltimport randomfrom numpy import argmax#Basic Params-----------------------------epoch = 1learning_rate = 0.01batch_size_train = 64batch_size_test = 1000gpu = torch.cuda.is_available()momentum = 0.5#Load Data-------------------------------train_loader = DataLoader(torchvision.datasets.MNIST('./data/', train=True, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( (0.1307,), (0.3081,)) ])), batch_size=batch_size_train, shuffle=True)test_loader = DataLoader(torchvision.datasets.MNIST('./data/', train=False, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( (0.1307,), (0.3081,)) ])), batch_size=batch_size_test, shuffle=True)train_data_size = len(train_loader)test_data_size = len(test_loader)#Define Model----------------------------class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.model = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Flatten(), nn.Linear(in_features=3136, out_features=128), nn.Linear(in_features=128, out_features=10), ) def forward(self, x): return self.model(x)if gpu: net = Net().cuda()else: net = Net()#Define Loss and Optimizer----------------if gpu: loss_fn = nn.CrossEntropyLoss().cuda()else: loss_fn = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)#Define Tensorboard-------------------writer = SummaryWriter(log_dir='logs/{}'.format(time.strftime('%Y%m%d-%H%M%S')))#Train---------------------------------total_train_step = 0def train(epoch): global total_train_step total_train_step = 0 for data in train_loader: imgs,targets = data if gpu: imgs,targets = imgs.cuda(),targets.cuda() optimizer.zero_grad() outputs = net(imgs) loss = loss_fn(outputs,targets) loss.backward() optimizer.step() if total_train_step % 200 == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, total_train_step, train_data_size, 100. * total_train_step / train_data_size, loss.item())) writer.add_scalar('loss', loss.item(), total_train_step) total_train_step += 1#Test---------------------------------def test(): correct = 0 total = 0 with torch.no_grad(): for data in test_loader: imgs,targets = data if gpu: imgs,targets = imgs.cuda(),targets.cuda() outputs = net(imgs) _,predicted = torch.max(outputs.data,1) total += targets.size(0) correct += (predicted == targets).sum().item() print('Test Accuracy: {}/{} ({:.0f}%)'.format(correct, total, 100.*correct/total)) return correct/total#Run----------------------------------for i in range(1,epoch+1): print("-----------------Epoch: {}-----------------".format(i)) train(i) test() writer.add_scalar('test_accuracy', test(), total_train_step) #save model torch.save(net,'model/mnist_model.pth') print('Saved model')writer.close()#Evaluate---------------------------------model = torch.load("./model/mnist_model.pth")model.eval()print(model)fig = plt.figure(figsize=(20,20))for i in range(20): #random number index = random.randint(0,test_data_size) data = test_loader.dataset[index] if gpu: img = data[0].cuda() else: img = data[0] img = torch.reshape(img,(1,1,28,28)) with torch.no_grad(): output = model(img) #plot the image and the predicted number fig.add_subplot(4,5,i+1) plt.title(argmax(output.data.cpu().numpy())) plt.imshow(data[0].numpy().squeeze(),cmap='gray')plt.show()
本文链接地址:https://www.jiuchutong.com/zhishi/300439.html 转载请保留说明!

上一篇:【yolov6系列一】深度解析网络架构(yolov5官方)

下一篇:【Spring】IOC,你真的懂了吗?(spring ioc di aop)

  • 关于教师节的感恩的话(关于教师节的感谢语)(关于教师节的感恩词)

    关于教师节的感恩的话(关于教师节的感谢语)(关于教师节的感恩词)

  • 石头g10尘盒怎么打开(石头机器人尘盒打开)

    石头g10尘盒怎么打开(石头机器人尘盒打开)

  • 抖音橱窗和小黄车有什么区别(抖音橱窗和小黄车没有银行卡能不能提现)

    抖音橱窗和小黄车有什么区别(抖音橱窗和小黄车没有银行卡能不能提现)

  • 红米k30s至尊纪念版屏幕是多大尺寸的(红米k30s至尊纪念版是什么屏幕)

    红米k30s至尊纪念版屏幕是多大尺寸的(红米k30s至尊纪念版是什么屏幕)

  • qq深色模式如何关闭(qq深色模式怎么开)

    qq深色模式如何关闭(qq深色模式怎么开)

  • 无法复制磁盘被写保护请去掉写保护(无法复制磁盘被写保护怎么解除)

    无法复制磁盘被写保护请去掉写保护(无法复制磁盘被写保护怎么解除)

  • 苹果12w和18w充电器对比(苹果12用18瓦和20瓦哪个充电器好)

    苹果12w和18w充电器对比(苹果12用18瓦和20瓦哪个充电器好)

  • 宏碁是国产吗(宏碁是国产吗知乎)

    宏碁是国产吗(宏碁是国产吗知乎)

  • cpu超频ring倍频要调吗(cpu超频 倍频)

    cpu超频ring倍频要调吗(cpu超频 倍频)

  • edi是什么之间的数据传输(什么是edi,它有哪些特点?)

    edi是什么之间的数据传输(什么是edi,它有哪些特点?)

  • 华为手机p40 pro什么时候上市(华为手机p40pro屏幕碎了维修大概要多少钱)

    华为手机p40 pro什么时候上市(华为手机p40pro屏幕碎了维修大概要多少钱)

  • 荣耀20青春版指纹在哪里(荣耀20青春版指纹解锁怎么用)

    荣耀20青春版指纹在哪里(荣耀20青春版指纹解锁怎么用)

  • 网络接入已满是什么意思(网络接入已满是手机问题吗)

    网络接入已满是什么意思(网络接入已满是手机问题吗)

  • id暂时禁止免费获取app怎么办(id暂时禁止免费获取app要多久)

    id暂时禁止免费获取app怎么办(id暂时禁止免费获取app要多久)

  • 程序和进程是一一对应的吗(程序和进程是一对应的第一个程序只对应一个进程)

    程序和进程是一一对应的吗(程序和进程是一对应的第一个程序只对应一个进程)

  • 到了国外还能用微信吗(到了国外还能用流量吗)

    到了国外还能用微信吗(到了国外还能用流量吗)

  • 调节电脑亮度的快捷键是什么(调节电脑亮度的按键)

    调节电脑亮度的快捷键是什么(调节电脑亮度的按键)

  • 拼多多保留5个团啥意思(拼多多保留5个团怎么操作)

    拼多多保留5个团啥意思(拼多多保留5个团怎么操作)

  • ug怎么把斜工件摆正(ug怎么把斜工件摆正视图方向)

    ug怎么把斜工件摆正(ug怎么把斜工件摆正视图方向)

  • 电脑买回来要做些什么(电脑买回来要做系统吗)

    电脑买回来要做些什么(电脑买回来要做系统吗)

  • 想用U盘装系统索尼笔记本如何进Bios设置U盘启动(想用u盘装系统怎么弄)

    想用U盘装系统索尼笔记本如何进Bios设置U盘启动(想用u盘装系统怎么弄)

  • 帝国cms模板文章列表分页的下划线如何修改(帝国cms模板文件放在哪里)

    帝国cms模板文章列表分页的下划线如何修改(帝国cms模板文件放在哪里)

  • 个税汇缴常见问题
  • 首付款计提税金吗
  • 税务师都有什么科目
  • 固定资产计提折旧计入什么科目
  • 增值税技术维护费每年都可以抵减吗?
  • 什么是未投入使用的固定资产
  • 招待客人的住宿费能抵扣吗?
  • 自建厂房出售如何计算所得税
  • 合法有效的凭证
  • 购买旧资产如何入账
  • 私车公用产生的过路费开个人发票还是公司发票
  • 金税三期社保费管理客户端v1.0.088(生产环境)
  • 子公司之间可以相互交易吗
  • 0退税产品怎么征税
  • 没有收入要做应交税费的会计分录
  • 企业开发票的人员要经过培训吗?
  • 劳务派遣公司发放工资是按照劳务报酬嘛
  • 个人出租仓库需交税吗
  • 公司招待客户买的水果怎么入账
  • 不动产权时间怎么确认
  • 初级备考需要多长时间
  • 委托加工白酒的计税依据
  • 年初预提费用
  • 全年平均职工人数按季度平均公式
  • 按揭的车可以只买交强险吗
  • win11怎么改名
  • 打开游戏时总是出现需要新应用打开此MS
  • 卸载软件怎么清理干净
  • 税务自查补缴税款的申报表在哪里找
  • 怎样让鼠标变得好看些
  • 小企业会计准则没有以前年度损益调整科目
  • 收到保险公司车辆保险发票会计分录
  • 事业单位需要交企业所得税吗
  • 进程控制块PCB不包括( )
  • 公司的净资产总值怎么算
  • 可供出售金融资产现在叫什么
  • 公允价值变动损益借贷方向增减
  • 兼职如何交税款
  • 埃托沙国家公园发展观兽旅游的优势条件
  • 购买股票会计分录怎么写
  • wordpress项目开发
  • 穆尔官网
  • php进程数设置
  • 城市维护建设税减免税优惠政策
  • ChatGPT频频发疯!马斯克警告:AI将毁灭人类
  • 华为od测试岗机试需要怎么准备
  • 税务端系统返回错误信息f50006
  • 微擎框架安装教程
  • 多线程并发python
  • 无追保理是什么意思
  • 秸秆回收加工项目
  • linux mysql忘记密码的多种解决或Access denied for user 'root'@'localhost'
  • 房屋出租收到的发票
  • 库存现金的使用限额规定
  • 营改增有关事项的规定
  • 费用报销的程序是什么
  • 组织机构代码证和统一社会信用代码的关系
  • 已启动申报比对异常怎样才能作废,还没过税期
  • 什么是企业合并?
  • 发票专用章需要备案吗?
  • 企业流动负债比率多少算正常
  • SQL Server的FileStream和FileTable深入剖析
  • CentOS6 32/64位安装Adobe Flash Player组件的方法
  • vmware虚拟机不能识别iso
  • windows设置
  • xp系统怎么取消密码怎么设置
  • windows10预览版怎么样
  • oodag.exe - oodag是什么进程 有什么作用
  • 装win8.1
  • 如何修改excel数据显示格式
  • css实现下拉菜单的思路是
  • perl中哈希如何赋值
  • Unity3D游戏开发基础
  • windows ipython
  • 带领大家学习javascript基础篇(一)之基本概念
  • 什么东西的海关不能寄
  • 买房送地下室土地可以吗
  • 西安车辆购置税缴纳多少
  • 瑞士州税
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设