位置: IT常识 - 正文

使用Pytorch实现深度学习的主要流程(pytorch技巧)

编辑:rootadmin
使用Pytorch实现深度学习的主要流程 一、使用Pytorch实现深度学习的主要流程

推荐整理分享使用Pytorch实现深度学习的主要流程(pytorch技巧),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch怎么用,pytorch例程,pytorch 简单例子,pytorch教程,pytorch基本操作,pytorch例程,pytorch怎么用,pytorch怎么用,内容如对您有帮助,希望把文章链接给更多的朋友!

使用Pytorch进行深度学习的实现流程主要包含如下几个部分: 1、预处理、后处理并确认网络的输入和输出 2、创建Dataset 3、创建DataLoader 4、创建网络模型 5、定义正向传播函数(forward) 6、定义损失函数 7、设置最优化算法 8、进行训练和验证 9、达到标准保存模型 10、加载模型使用测试数据进行测试 在使用Pytorch实现深度学习的整体流程中,首先需要对准备实现的深度学习算法从整体上进行把握。即对预处理以及后处理,网络模型的输入和输出进行确认。 创建Dataset就是将输入数据和与其对应的标签组成配对数据进行保存的类。这里将用于处理数据的预处理类的实例指定到Dataset中,并设定其在从文件中读取对象数据时,自动对输入数据进行预处理。 接下来是创建Dataloader,DataLoader是用来设定从Dataset中读取数据的具体方法的类。在深度学习中,通常都是采用小批次学习的方式,将多个数据同时从Dataset中取出,并传递给神经网络进行学习训练,DataLoader就是负责简化从Dataset中取出小批次数据这一操作的类,需要分别创建好用于训练数据以及验证数据的Dataloader。 接下来是创建网络模型,创建网络模型共有三种方式,第一种从零开始实现整个网络模型;第二种是直接载入已经训练好的网络模型,第三种是以现有训练好的网络模型为基础,将其改造为自己需要的模型。在深度学习实际应用中,大多数情况是以训练好的网络模型为基础,将其改造成符合自身需要的模型。 在成功创建网络模型之后,就需要定义网络模型的正向传播函数,forward函数,接下来要做的就是定义用于将误差值进行反向传播的损失函数,对于解决不同的任务会设置不同的损失函数。 下一步就是设定在对网络模型的连接参数进行训练时使用的优化算法,通过误差的反向传播,可以对连接参数的误差对应的梯度进行计算。优化算法就是指如何根据这一梯度值计算出连接参数的修正量具体算法。常用的优化算法有 Adam、SGD。 通过上述步骤,就完成了进行深度学习所需要的所有设置,接下来进行实际的学习和验证操作。通常以epoch为单位,对训练数据的性能和验证数据的性能进行比较,如果验证数据的性能停止提升,之后的训练都会陷入到过拟合状态,因此需要及时停止训练,提前终止网络学习的方法又称为early stopping。 学习完成之后保存我们训练得到的最优模型,之后加载模型进行推理。

二、代码实战2.1、软件包的导入以及初始设置# 实现代码的初始设置import globimport os.path as ospimport randomimport numpy as npimport jsonfrom PIL import Imagefrom tqdm import tqdmimport matplotlib.pyplot as pltimport torchimport torch.nn as nnimport torch.optim as optimimport torch.utils.data as dataimport torchvisionfrom torchvision import models, transformstorch.manual_seed(1234)np.random.seed(1234)random.seed(1234)2.2、创建Dataset# 创建DataSetclass ImageTransform(): def __init__(self, resize, mean, std): self.data_transform = { 'train': transforms.Compose([ transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean, std) ]), 'val': transforms.Compose([ transforms.Resize(resize), transforms.CenterCrop(resize), transforms.ToTensor(), transforms.Normalize(mean, std) ]) } def __call__(self, img, phase='train'): return self.data_transform[phase](img)2.3、查看图像预处理前后的对比# 1.读取图像image_file_path = './data/goldenretriever-3724972_1280.jpg'img = Image.open(image_file_path)# 2.显示原图# img.show()plt.imshow(img)plt.show()# 3.预处理size = 224mean = (0.485, 0.456, 0.406)std = (0.229, 0.224, 0.225)transform = ImageTransform(size, mean, std)img_transformed = transform(img, phase='train')img_transformed = img_transformed.numpy().transpose((1, 2, 0))img_transformed = np.clip(img_transformed, 0, 1)plt.imshow(img_transformed)plt.show()

2.4、创建用于保存图片文件路径的列表变量# 用于保存蚂蚁和蜜蜂的图片文件路径列表变量def make_data_path_list(phase='train'): root_path = './data/hymenoptera_data/' target_path = osp.join(root_path + phase + '/**/*.jpg') # print(target_path) # ./data/hymenoptera_data/train/**/*.jpg # ./data/hymenoptera_data/val/**/*.jpg path_list = [] for path in glob.glob(target_path): path_list.append(path) return path_listtrain_list = make_data_path_list(phase='train')val_list = make_data_path_list(phase='val')print(train_list[:5])使用Pytorch实现深度学习的主要流程(pytorch技巧)

[‘./data/hymenoptera_data/train/bees/2638074627_6b3ae746a0.jpg’, ‘./data/hymenoptera_data/train/bees/507288830_f46e8d4cb2.jpg’, ‘./data/hymenoptera_data/train/bees/2405441001_b06c36fa72.jpg’, ‘./data/hymenoptera_data/train/bees/2962405283_22718d9617.jpg’, ‘./data/hymenoptera_data/train/bees/446296270_d9e8b93ecf.jpg’]

2.5、创建图片组成的Dataset# 构建datasetclass HymenopteraDataset(data.Dataset): def __init__(self, file_list, transform=None, phase='train'): self.file_list = file_list self.transform = transform self.phase = phase def __len__(self): return len(self.file_list) def __getitem__(self, index): img_path = self.file_list[index] img = Image.open(img_path) img_transformed = self.transform(img, self.phase) if self.phase == 'train': label = img_path[30:34] elif self.phase == 'val': label = img_path[28: 32] if label == 'ants': label = 0 elif label == 'bees': label = 1 return img_transformed, labeltrain_dataset = HymenopteraDataset( file_list=train_list, transform=ImageTransform(size, mean, std), phase='train')val_dataset = HymenopteraDataset( file_list=val_list, transform=ImageTransform(size, mean, std), phase='val')# 查看图片和标签index = 0print(train_dataset.__getitem__(index)[0].size())print(train_dataset.__getitem__(index)[1])

torch.Size([3, 224, 224]) 1

2.6、创建Dataloader# 创建DataLoaderbatch_size = 32train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, shuffle=True)val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=batch_size, shuffle=False)dataloaders_dict = {'train': train_loader, 'val': val_loader}batch_iterator = iter(dataloaders_dict['train'])inputs, labels = next(batch_iterator)print(inputs.size())print(labels)

torch.Size([32, 3, 224, 224]) tensor([0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1])

2.7、创建网络模型# 创建网络模型use_pretrained = Truenet = models.vgg16(pretrained=use_pretrained)print(net)net.classifier[6] = nn.Linear(in_features=4096, out_features=2)print(net)net.train()print('网络设置完毕: 载入已经学习完毕的权重,并设置为训练模式')

VGG( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace=True) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace=True) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace=True) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace=True) (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (18): ReLU(inplace=True) (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace=True) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace=True) (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (25): ReLU(inplace=True) (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (27): ReLU(inplace=True) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace=True) (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) (classifier): Sequential( (0): Linear(in_features=25088, out_features=4096, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU(inplace=True) (5): Dropout(p=0.5, inplace=False) (6): Linear(in_features=4096, out_features=1000, bias=True) ) ) VGG( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace=True) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace=True) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace=True) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace=True) (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (18): ReLU(inplace=True) (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace=True) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace=True) (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (25): ReLU(inplace=True) (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (27): ReLU(inplace=True) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace=True) (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) (classifier): Sequential( (0): Linear(in_features=25088, out_features=4096, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU(inplace=True) (5): Dropout(p=0.5, inplace=False) (6): Linear(in_features=4096, out_features=2, bias=True) ) ) 网络设置完毕: 载入已经学习完毕的权重,并设置为训练模式

2.8、定义损失函数# 定义损失函数criterion = nn.CrossEntropyLoss()2.9、设定最优化算法# 设定最优化算法params_to_update = []# print(net.named_parameters())update_param_names = ['classifier.6.weight', 'classifier.6.bias']for name, param in net.named_parameters(): # print(name) # print(param) if name in update_param_names: param.requires_grad = True params_to_update.append(param) print(name) else: param.requires_grad = Falseprint('=========================')print(params_to_update)# 设置最优化算法optimizer = optim.SGD(params=params_to_update, lr=0.001, momentum=0.9)print(optimizer)classifier.6.weight classifier.6.bias

[Parameter containing: tensor([[-0.0048, 0.0072, -0.0081, …, 0.0003, -0.0040, 0.0048], [ 0.0051, 0.0072, -0.0154, …, 0.0054, 0.0152, 0.0083]], requires_grad=True), Parameter containing: tensor([ 0.0108, -0.0054], requires_grad=True)] SGD ( Parameter Group 0 dampening: 0 lr: 0.001 momentum: 0.9 nesterov: False weight_decay: 0 )

2.10、训练和验证# 学习和验证的实行def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs): for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch + 1, num_epochs)) print('--------------------') for phase in ['train', 'val']: if phase == 'train': net.train() else: net.eval() epoch_loss = 0.0 epoch_corrects = 0 if epoch == 0 and phase == 'train': continue for inputs, labels in tqdm(dataloaders_dict[phase]): optimizer.zero_grad() with torch.set_grad_enabled(phase == 'train'): outputs = net(inputs) loss = criterion(outputs, labels) _, preds = torch.max(outputs, 1) if phase == 'train': loss.backward() optimizer.step() epoch_loss += loss.item() * inputs.size(0) epoch_corrects += torch.sum(preds == labels.data) epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset) epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset) print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))2.11、主函数if __name__ == '__main__': num_epochs = 2 train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)

2.12、保存和加载网络模型 # 保存和读取训练完毕的网络 # save_path = './weights_fine_tuning.pth' # torch.save(net.state_dict(), save_path) # 加载训练好的网络参数 load_path = './weights_fine_tuning.pth' load_weights = torch.load(load_path) net.load_state_dict(load_weights) print(net)
本文链接地址:https://www.jiuchutong.com/zhishi/300555.html 转载请保留说明!

上一篇:使用el-upload组件实现递归多文件上传(elementui的upload组件详解)

下一篇:vue3 响应式 API 之 ref(vue3 响应式ui框架)

  • iphonese2支持快充吗(苹果se2支持的快充)

    iphonese2支持快充吗(苹果se2支持的快充)

  • 相机比手机拍照好在哪(相机比手机拍照好看的原因)

    相机比手机拍照好在哪(相机比手机拍照好看的原因)

  • win10开机报0xc0000001(win10开机报0xc0000007b)

    win10开机报0xc0000001(win10开机报0xc0000007b)

  • 支付宝合种爱情树多少能量能长大(支付宝合种爱情树名称大全)

    支付宝合种爱情树多少能量能长大(支付宝合种爱情树名称大全)

  • 微信极简模式怎么设置(微信极简模式怎么打开)

    微信极简模式怎么设置(微信极简模式怎么打开)

  • nxttl00是什么型号

    nxttl00是什么型号

  • 喜马拉雅vip为什么还要付费(喜马拉雅vip为什么还要买喜点)

    喜马拉雅vip为什么还要付费(喜马拉雅vip为什么还要买喜点)

  • 探探为什么这么容易封(探探为什么这么容易被冻结)

    探探为什么这么容易封(探探为什么这么容易被冻结)

  • 小米10和10青春版有什么区别(小米10与小米10青春)

    小米10和10青春版有什么区别(小米10与小米10青春)

  • 苹果xs关机键怎么是siri(苹果xs关机快捷键)

    苹果xs关机键怎么是siri(苹果xs关机快捷键)

  • 抖音号可以电脑手机同时登录吗(抖音号电脑登录入口)

    抖音号可以电脑手机同时登录吗(抖音号电脑登录入口)

  • 锐龙r53500u相当于i几(锐龙r53500怎么样)

    锐龙r53500u相当于i几(锐龙r53500怎么样)

  • ipad6是什么处理器(ipad6用的什么处理器)

    ipad6是什么处理器(ipad6用的什么处理器)

  • 腾讯大王卡看腾讯视频免流量吗(腾讯大王卡看腾讯视频怎么激活)

    腾讯大王卡看腾讯视频免流量吗(腾讯大王卡看腾讯视频怎么激活)

  • 60fps是什么意思(60fps和30fps哪个清晰)

    60fps是什么意思(60fps和30fps哪个清晰)

  • 华为桌面圆圈怎么取消(华为桌面圆圈怎么弄掉?)

    华为桌面圆圈怎么取消(华为桌面圆圈怎么弄掉?)

  • newtv是腾讯的吗(腾讯的tv叫什么)

    newtv是腾讯的吗(腾讯的tv叫什么)

  • 一加7T Pro怎么强制关机(一加七pro怎么操作)

    一加7T Pro怎么强制关机(一加七pro怎么操作)

  • 抖音为什么关注不了了(抖音为什么关注的人自动取消了)

    抖音为什么关注不了了(抖音为什么关注的人自动取消了)

  • amoled和oled哪个伤眼(amoled和oled哪个屏幕好)

    amoled和oled哪个伤眼(amoled和oled哪个屏幕好)

  • nex3防水吗(vivo的nex3防不防水)

    nex3防水吗(vivo的nex3防不防水)

  • 被拉到亲情号怎么退出(亲情号会被拉黑吗)

    被拉到亲情号怎么退出(亲情号会被拉黑吗)

  • 锤子手机投屏怎么设置(锤子手机怎么投屏投影仪)

    锤子手机投屏怎么设置(锤子手机怎么投屏投影仪)

  • gcasServ.exe是什么进程 作用是什么 gcasServ进程查询(g++.exe error)

    gcasServ.exe是什么进程 作用是什么 gcasServ进程查询(g++.exe error)

  • 一块石炭纪蕨类化石 (© Juan Carlos Munoz/Minden Pictures)(石炭纪的树有多高)

    一块石炭纪蕨类化石 (© Juan Carlos Munoz/Minden Pictures)(石炭纪的树有多高)

  • 房产税计入管理费用了,汇算清缴怎么调
  • 个人购买二手房贷款能贷多少
  • 使用增值税发票的条件
  • 消费税的计算方法有哪三种
  • 购物税费怎么算
  • 土地使用税的纳税时间
  • 登记会计账簿的内容包括
  • 车船抵扣如何填报
  • 印花税核定征收管理办法
  • 非营利组织管理规定
  • 退货手续费账务怎么处理
  • 资本化的借款利息支出计入什么科目
  • 建筑公司现金日记账怎么填写
  • 代订机票款发票可以作为机票报销差旅吗
  • 固定资产抵扣税金算增值税吗怎么算
  • 税收缴款书税务收现专用的用途
  • 一般纳税人认定标准500万是什么时候开始执行
  • 政府补贴专项资金如何入账
  • 对方公司已注销,我公司应付款怎样支付
  • 出售股票公允价值变动损益
  • 收取专利使用费怎么支出
  • 多交增值税如何做账
  • 项目材料验收流程
  • 让记事本文件自动删除
  • 用系统自带命令行安装WIN10
  • kb4586853更新
  • 鸿蒙系统怎么开启OTG
  • 计提固定资产折旧怎么做会计科目
  • codecline
  • 谷歌浏览器如何设置主页为默认页
  • 累计盈余科目怎么填
  • 季度利润表中的营业收入怎么算
  • php制作验证码
  • 车船税发票丢失
  • 普通纳税人怎么交税
  • 兼职员工的工资怎么发放
  • 金税盘白盘怎么分发发票
  • 商品流通企业物流成本的具体构成包括()
  • 哪些情况进项税不可以抵扣?
  • 所有者权益总计是什么
  • 资产处置费用是资产类会计科目吗
  • for循环语法结构是什么
  • 土地使用权是指企业所拥有的
  • 固定资产盘盈属于其他业务收入吗
  • MySQL/Postgrsql 详细讲解如何用ODBC接口访问MySQL指南
  • 采购材料单表格
  • 股东退股如何清算表格
  • 固定制造费用差异的意义
  • 在异地施工就要在异地交税吗
  • 房地产企业如何结转成本
  • 工程收据怎么开表格
  • 房屋租赁合同印花税的税率
  • 什么是盈亏平衡法
  • 如何控制自己不磨牙
  • sql server数据库恢复
  • mysql安全性控制语句
  • sqlserver升级到2016
  • centos怎么样
  • P2PNetworking3.exe - P2PNetworking3是什么进程 有什么用
  • win10预览版21337
  • win8怎么设置桌面背景
  • cocos2dx node
  • unity3ds
  • js瀑布流效果代码
  • java程序员准备骑驴找马了,需要怎么准备
  • jquery排序上升和排序下降
  • 初识年岁尚温柔 小说 免费
  • jquery移动div到另一个div中
  • nodejs深入浅出pdf百度云
  • 可交互原型是什么
  • 网管的功能
  • node stream(流)有哪些?
  • 如何开发一个新的向量库
  • 广东电子税务局手机版
  • 山东省立第三医院地址
  • 郑州契税怎么收
  • 增值税普通发票有什么用
  • 选矿比怎么算
  • 衡阳地税局的地理位置
  • 盘州市税务局党组成员图片
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设