位置: IT常识 - 正文

使用Pytorch实现深度学习的主要流程(pytorch技巧)

发布时间:2024-01-17
使用Pytorch实现深度学习的主要流程 一、使用Pytorch实现深度学习的主要流程

推荐整理分享使用Pytorch实现深度学习的主要流程(pytorch技巧),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch怎么用,pytorch例程,pytorch 简单例子,pytorch教程,pytorch基本操作,pytorch例程,pytorch怎么用,pytorch怎么用,内容如对您有帮助,希望把文章链接给更多的朋友!

使用Pytorch进行深度学习的实现流程主要包含如下几个部分: 1、预处理、后处理并确认网络的输入和输出 2、创建Dataset 3、创建DataLoader 4、创建网络模型 5、定义正向传播函数(forward) 6、定义损失函数 7、设置最优化算法 8、进行训练和验证 9、达到标准保存模型 10、加载模型使用测试数据进行测试 在使用Pytorch实现深度学习的整体流程中,首先需要对准备实现的深度学习算法从整体上进行把握。即对预处理以及后处理,网络模型的输入和输出进行确认。 创建Dataset就是将输入数据和与其对应的标签组成配对数据进行保存的类。这里将用于处理数据的预处理类的实例指定到Dataset中,并设定其在从文件中读取对象数据时,自动对输入数据进行预处理。 接下来是创建Dataloader,DataLoader是用来设定从Dataset中读取数据的具体方法的类。在深度学习中,通常都是采用小批次学习的方式,将多个数据同时从Dataset中取出,并传递给神经网络进行学习训练,DataLoader就是负责简化从Dataset中取出小批次数据这一操作的类,需要分别创建好用于训练数据以及验证数据的Dataloader。 接下来是创建网络模型,创建网络模型共有三种方式,第一种从零开始实现整个网络模型;第二种是直接载入已经训练好的网络模型,第三种是以现有训练好的网络模型为基础,将其改造为自己需要的模型。在深度学习实际应用中,大多数情况是以训练好的网络模型为基础,将其改造成符合自身需要的模型。 在成功创建网络模型之后,就需要定义网络模型的正向传播函数,forward函数,接下来要做的就是定义用于将误差值进行反向传播的损失函数,对于解决不同的任务会设置不同的损失函数。 下一步就是设定在对网络模型的连接参数进行训练时使用的优化算法,通过误差的反向传播,可以对连接参数的误差对应的梯度进行计算。优化算法就是指如何根据这一梯度值计算出连接参数的修正量具体算法。常用的优化算法有 Adam、SGD。 通过上述步骤,就完成了进行深度学习所需要的所有设置,接下来进行实际的学习和验证操作。通常以epoch为单位,对训练数据的性能和验证数据的性能进行比较,如果验证数据的性能停止提升,之后的训练都会陷入到过拟合状态,因此需要及时停止训练,提前终止网络学习的方法又称为early stopping。 学习完成之后保存我们训练得到的最优模型,之后加载模型进行推理。

二、代码实战2.1、软件包的导入以及初始设置# 实现代码的初始设置import globimport os.path as ospimport randomimport numpy as npimport jsonfrom PIL import Imagefrom tqdm import tqdmimport matplotlib.pyplot as pltimport torchimport torch.nn as nnimport torch.optim as optimimport torch.utils.data as dataimport torchvisionfrom torchvision import models, transformstorch.manual_seed(1234)np.random.seed(1234)random.seed(1234)2.2、创建Dataset# 创建DataSetclass ImageTransform(): def __init__(self, resize, mean, std): self.data_transform = { 'train': transforms.Compose([ transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean, std) ]), 'val': transforms.Compose([ transforms.Resize(resize), transforms.CenterCrop(resize), transforms.ToTensor(), transforms.Normalize(mean, std) ]) } def __call__(self, img, phase='train'): return self.data_transform[phase](img)2.3、查看图像预处理前后的对比# 1.读取图像image_file_path = './data/goldenretriever-3724972_1280.jpg'img = Image.open(image_file_path)# 2.显示原图# img.show()plt.imshow(img)plt.show()# 3.预处理size = 224mean = (0.485, 0.456, 0.406)std = (0.229, 0.224, 0.225)transform = ImageTransform(size, mean, std)img_transformed = transform(img, phase='train')img_transformed = img_transformed.numpy().transpose((1, 2, 0))img_transformed = np.clip(img_transformed, 0, 1)plt.imshow(img_transformed)plt.show()

2.4、创建用于保存图片文件路径的列表变量# 用于保存蚂蚁和蜜蜂的图片文件路径列表变量def make_data_path_list(phase='train'): root_path = './data/hymenoptera_data/' target_path = osp.join(root_path + phase + '/**/*.jpg') # print(target_path) # ./data/hymenoptera_data/train/**/*.jpg # ./data/hymenoptera_data/val/**/*.jpg path_list = [] for path in glob.glob(target_path): path_list.append(path) return path_listtrain_list = make_data_path_list(phase='train')val_list = make_data_path_list(phase='val')print(train_list[:5])使用Pytorch实现深度学习的主要流程(pytorch技巧)

[‘./data/hymenoptera_data/train/bees/2638074627_6b3ae746a0.jpg’, ‘./data/hymenoptera_data/train/bees/507288830_f46e8d4cb2.jpg’, ‘./data/hymenoptera_data/train/bees/2405441001_b06c36fa72.jpg’, ‘./data/hymenoptera_data/train/bees/2962405283_22718d9617.jpg’, ‘./data/hymenoptera_data/train/bees/446296270_d9e8b93ecf.jpg’]

2.5、创建图片组成的Dataset# 构建datasetclass HymenopteraDataset(data.Dataset): def __init__(self, file_list, transform=None, phase='train'): self.file_list = file_list self.transform = transform self.phase = phase def __len__(self): return len(self.file_list) def __getitem__(self, index): img_path = self.file_list[index] img = Image.open(img_path) img_transformed = self.transform(img, self.phase) if self.phase == 'train': label = img_path[30:34] elif self.phase == 'val': label = img_path[28: 32] if label == 'ants': label = 0 elif label == 'bees': label = 1 return img_transformed, labeltrain_dataset = HymenopteraDataset( file_list=train_list, transform=ImageTransform(size, mean, std), phase='train')val_dataset = HymenopteraDataset( file_list=val_list, transform=ImageTransform(size, mean, std), phase='val')# 查看图片和标签index = 0print(train_dataset.__getitem__(index)[0].size())print(train_dataset.__getitem__(index)[1])

torch.Size([3, 224, 224]) 1

2.6、创建Dataloader# 创建DataLoaderbatch_size = 32train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, shuffle=True)val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=batch_size, shuffle=False)dataloaders_dict = {'train': train_loader, 'val': val_loader}batch_iterator = iter(dataloaders_dict['train'])inputs, labels = next(batch_iterator)print(inputs.size())print(labels)

torch.Size([32, 3, 224, 224]) tensor([0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1])

2.7、创建网络模型# 创建网络模型use_pretrained = Truenet = models.vgg16(pretrained=use_pretrained)print(net)net.classifier[6] = nn.Linear(in_features=4096, out_features=2)print(net)net.train()print('网络设置完毕: 载入已经学习完毕的权重,并设置为训练模式')

VGG( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace=True) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace=True) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace=True) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace=True) (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (18): ReLU(inplace=True) (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace=True) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace=True) (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (25): ReLU(inplace=True) (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (27): ReLU(inplace=True) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace=True) (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) (classifier): Sequential( (0): Linear(in_features=25088, out_features=4096, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU(inplace=True) (5): Dropout(p=0.5, inplace=False) (6): Linear(in_features=4096, out_features=1000, bias=True) ) ) VGG( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace=True) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace=True) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace=True) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace=True) (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (18): ReLU(inplace=True) (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace=True) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace=True) (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (25): ReLU(inplace=True) (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (27): ReLU(inplace=True) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace=True) (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) (classifier): Sequential( (0): Linear(in_features=25088, out_features=4096, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU(inplace=True) (5): Dropout(p=0.5, inplace=False) (6): Linear(in_features=4096, out_features=2, bias=True) ) ) 网络设置完毕: 载入已经学习完毕的权重,并设置为训练模式

2.8、定义损失函数# 定义损失函数criterion = nn.CrossEntropyLoss()2.9、设定最优化算法# 设定最优化算法params_to_update = []# print(net.named_parameters())update_param_names = ['classifier.6.weight', 'classifier.6.bias']for name, param in net.named_parameters(): # print(name) # print(param) if name in update_param_names: param.requires_grad = True params_to_update.append(param) print(name) else: param.requires_grad = Falseprint('=========================')print(params_to_update)# 设置最优化算法optimizer = optim.SGD(params=params_to_update, lr=0.001, momentum=0.9)print(optimizer)classifier.6.weight classifier.6.bias

[Parameter containing: tensor([[-0.0048, 0.0072, -0.0081, …, 0.0003, -0.0040, 0.0048], [ 0.0051, 0.0072, -0.0154, …, 0.0054, 0.0152, 0.0083]], requires_grad=True), Parameter containing: tensor([ 0.0108, -0.0054], requires_grad=True)] SGD ( Parameter Group 0 dampening: 0 lr: 0.001 momentum: 0.9 nesterov: False weight_decay: 0 )

2.10、训练和验证# 学习和验证的实行def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs): for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch + 1, num_epochs)) print('--------------------') for phase in ['train', 'val']: if phase == 'train': net.train() else: net.eval() epoch_loss = 0.0 epoch_corrects = 0 if epoch == 0 and phase == 'train': continue for inputs, labels in tqdm(dataloaders_dict[phase]): optimizer.zero_grad() with torch.set_grad_enabled(phase == 'train'): outputs = net(inputs) loss = criterion(outputs, labels) _, preds = torch.max(outputs, 1) if phase == 'train': loss.backward() optimizer.step() epoch_loss += loss.item() * inputs.size(0) epoch_corrects += torch.sum(preds == labels.data) epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset) epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset) print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))2.11、主函数if __name__ == '__main__': num_epochs = 2 train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)

2.12、保存和加载网络模型 # 保存和读取训练完毕的网络 # save_path = './weights_fine_tuning.pth' # torch.save(net.state_dict(), save_path) # 加载训练好的网络参数 load_path = './weights_fine_tuning.pth' load_weights = torch.load(load_path) net.load_state_dict(load_weights) print(net)
本文链接地址:https://www.jiuchutong.com/zhishi/300555.html 转载请保留说明!

上一篇:使用el-upload组件实现递归多文件上传(elementui的upload组件详解)

下一篇:vue3 响应式 API 之 ref(vue3 响应式ui框架)

  • 教你怎样给你的微博营销化化妆?(如何给一个)

    教你怎样给你的微博营销化化妆?(如何给一个)

  • 如何关闭qq下拉功能(如何关闭qq下拉抢红包)

    如何关闭qq下拉功能(如何关闭qq下拉抢红包)

  • 苹果提示充电口有液体(苹果提示充电口有水)

    苹果提示充电口有液体(苹果提示充电口有水)

  • 蓝屏代码0x00000050(蓝屏代码0x0000005无限重启)

    蓝屏代码0x00000050(蓝屏代码0x0000005无限重启)

  • 钉钉直播可以放视频吗(钉钉直播可以放大吗)

    钉钉直播可以放视频吗(钉钉直播可以放大吗)

  • 水星wifi管理员密码是什么(水星wifi管理员密码忘了怎么重新设置)

    水星wifi管理员密码是什么(水星wifi管理员密码忘了怎么重新设置)

  • ipad11pro尺寸(ipad11pro尺寸多大)

    ipad11pro尺寸(ipad11pro尺寸多大)

  • 小米哪个型号是旗舰机(小米哪个型号是直板屏手机)

    小米哪个型号是旗舰机(小米哪个型号是直板屏手机)

  • 手机有什么办法可以省电(打印机不能连手机有什么办法)

    手机有什么办法可以省电(打印机不能连手机有什么办法)

  • 网络适配器多路传送器协议打钩吗(网络适配器多路传输协议勾选无效)

    网络适配器多路传送器协议打钩吗(网络适配器多路传输协议勾选无效)

  • 苹果小圆圈怎么设置(苹果小圆圈怎么搞出来)

    苹果小圆圈怎么设置(苹果小圆圈怎么搞出来)

  • 华为hicar什么时候上线(华为HiCar什么时候能用)

    华为hicar什么时候上线(华为HiCar什么时候能用)

  • iphone11有几个摄像头(apple11几个摄像头)

    iphone11有几个摄像头(apple11几个摄像头)

  • CAD将画的图虚拟打印到文件(cad虚实变换)

    CAD将画的图虚拟打印到文件(cad虚实变换)

  • 小爱同学怎么改换姓名(小爱同学怎么改密码)

    小爱同学怎么改换姓名(小爱同学怎么改密码)

  • 苹果怎么看照片大小(苹果怎么看照片分辨率是多少dpi)

    苹果怎么看照片大小(苹果怎么看照片分辨率是多少dpi)

  • 腾讯hdr臻彩视界什么意思(腾讯hdr臻彩视界和蓝光哪个好)

    腾讯hdr臻彩视界什么意思(腾讯hdr臻彩视界和蓝光哪个好)

  • 小米9怎么设置开机动画(小米9怎么设置返回键)

    小米9怎么设置开机动画(小米9怎么设置返回键)

  • QQ音乐设置面板在哪里(qq音乐系统设置)

    QQ音乐设置面板在哪里(qq音乐系统设置)

  • 华为除了手机还有什么业务(华为除了手机还生产什么)

    华为除了手机还有什么业务(华为除了手机还生产什么)

  • 为什么最右会停止运行(最右怎么回事)

    为什么最右会停止运行(最右怎么回事)

  • 苹果XR怎么放两张卡(苹果xr怎么放两张卡槽)

    苹果XR怎么放两张卡(苹果xr怎么放两张卡槽)

  • 在Mac下如何安装Win10有哪几种方法(macbook如何安装)

    在Mac下如何安装Win10有哪几种方法(macbook如何安装)

  • 【javaScript】学完js基础,顺便把js高级语法学了(尚硅谷视频学习笔记)(javascript学什么内容)

    【javaScript】学完js基础,顺便把js高级语法学了(尚硅谷视频学习笔记)(javascript学什么内容)

  • 小规模纳税人月销售额超过15万
  • 增值税税率改变后原项目的新增单价按哪个税率
  • 划转税务的非税发票
  • 委托代理合同后果的承担
  • 纳税申报意思
  • 西部大开发政策2020到期
  • 不动产经营租赁属于现代服务吗
  • 洗衣店每个月水电费多少钱
  • 慰问金怎么入账科目
  • 如果没有预缴就开票会怎样?
  • 消费满额赠礼
  • 完税证明已开回怎么处理
  • 个人退回公积金怎么操作
  • 交强险保单被保险人写谁都行?
  • windows10开机如何换帐号
  • 建筑工程账务处理是在哪个阶段
  • 单位社保部分会扣吗
  • 激进型和保守型筹资组合怎么判断
  • 购进材料用于在建工程进项税
  • 累计摊销在资产里怎么算
  • 发票金额大于报销金额可以吗
  • php auth_http类库进行身份效验
  • 工程暂估收入入账的会计分录
  • 递延性负债
  • 长期借款和短期借款会计分录的区别
  • 科研的成果形式
  • ecshop功能
  • 增值税如何在报表里填写
  • 固定资产如何抵成本
  • 资产负债表项目填列的依据是
  • hypergraph learning
  • 基于Perclos&改进YOLOv7的疲劳驾驶DMS检测系统(源码&教程)
  • 大学生创新创业大赛官网
  • 新增总产值
  • 购销合同印花税最新政策2023
  • 应计入财务费用的科目是
  • ps魔棒工具选择图像时在容差数值较大的是
  • dedecms 授权
  • 终止合约取得的合约
  • 文化传媒有限公司英文
  • 弥补以前年度亏损怎么算
  • 会计人员未参加继续教育
  • 残保金怎么计提和缴纳
  • 减免的企业所得税怎么做账
  • 转账进公户
  • 核定征收适用于什么税率
  • 用友t3采购订单怎么录入
  • 行政事业单位过节费发放规定
  • 差额征税的账务处理教学视频
  • 工伤保险交了就可以报销吗
  • 新开办的企业怎么做账
  • 一般纳税人购买汽车会计分录
  • 利息收入记借方还是贷方
  • 企业的免税收入范围
  • 上期留抵本期抵扣怎么做分录
  • 对公账户怎么打印
  • 资产质量的相对性举例说明
  • window怎么操作
  • xp系统设置锁屏
  • windows允许多用户登录
  • mac闹钟app
  • win7应用程序无法正常启动
  • linux epub阅读器
  • win8系统安装条件
  • ie 无法打开
  • opengl入门教程(精)
  • linux awk命令使用实例
  • unity开发用什么电脑比较好
  • node_modules复制
  • node一次执行多个文件
  • nodejs全局异常监听
  • 浅谈python中的实例方法、类方法和静态方法
  • python自动化源码
  • javascript 自定义类
  • 土地增值税按什么价格
  • 内蒙古国家税务局网上电子税务局官网
  • 装卸搬运费属于
  • 互城通怎么用微信充值
  • 残疾人有车能否坐公交车
  • 房产折旧怎么算
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号