位置: IT常识 - 正文

使用Pytorch实现深度学习的主要流程(pytorch技巧)

编辑:rootadmin
使用Pytorch实现深度学习的主要流程 一、使用Pytorch实现深度学习的主要流程

推荐整理分享使用Pytorch实现深度学习的主要流程(pytorch技巧),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch怎么用,pytorch例程,pytorch 简单例子,pytorch教程,pytorch基本操作,pytorch例程,pytorch怎么用,pytorch怎么用,内容如对您有帮助,希望把文章链接给更多的朋友!

使用Pytorch进行深度学习的实现流程主要包含如下几个部分: 1、预处理、后处理并确认网络的输入和输出 2、创建Dataset 3、创建DataLoader 4、创建网络模型 5、定义正向传播函数(forward) 6、定义损失函数 7、设置最优化算法 8、进行训练和验证 9、达到标准保存模型 10、加载模型使用测试数据进行测试 在使用Pytorch实现深度学习的整体流程中,首先需要对准备实现的深度学习算法从整体上进行把握。即对预处理以及后处理,网络模型的输入和输出进行确认。 创建Dataset就是将输入数据和与其对应的标签组成配对数据进行保存的类。这里将用于处理数据的预处理类的实例指定到Dataset中,并设定其在从文件中读取对象数据时,自动对输入数据进行预处理。 接下来是创建Dataloader,DataLoader是用来设定从Dataset中读取数据的具体方法的类。在深度学习中,通常都是采用小批次学习的方式,将多个数据同时从Dataset中取出,并传递给神经网络进行学习训练,DataLoader就是负责简化从Dataset中取出小批次数据这一操作的类,需要分别创建好用于训练数据以及验证数据的Dataloader。 接下来是创建网络模型,创建网络模型共有三种方式,第一种从零开始实现整个网络模型;第二种是直接载入已经训练好的网络模型,第三种是以现有训练好的网络模型为基础,将其改造为自己需要的模型。在深度学习实际应用中,大多数情况是以训练好的网络模型为基础,将其改造成符合自身需要的模型。 在成功创建网络模型之后,就需要定义网络模型的正向传播函数,forward函数,接下来要做的就是定义用于将误差值进行反向传播的损失函数,对于解决不同的任务会设置不同的损失函数。 下一步就是设定在对网络模型的连接参数进行训练时使用的优化算法,通过误差的反向传播,可以对连接参数的误差对应的梯度进行计算。优化算法就是指如何根据这一梯度值计算出连接参数的修正量具体算法。常用的优化算法有 Adam、SGD。 通过上述步骤,就完成了进行深度学习所需要的所有设置,接下来进行实际的学习和验证操作。通常以epoch为单位,对训练数据的性能和验证数据的性能进行比较,如果验证数据的性能停止提升,之后的训练都会陷入到过拟合状态,因此需要及时停止训练,提前终止网络学习的方法又称为early stopping。 学习完成之后保存我们训练得到的最优模型,之后加载模型进行推理。

二、代码实战2.1、软件包的导入以及初始设置# 实现代码的初始设置import globimport os.path as ospimport randomimport numpy as npimport jsonfrom PIL import Imagefrom tqdm import tqdmimport matplotlib.pyplot as pltimport torchimport torch.nn as nnimport torch.optim as optimimport torch.utils.data as dataimport torchvisionfrom torchvision import models, transformstorch.manual_seed(1234)np.random.seed(1234)random.seed(1234)2.2、创建Dataset# 创建DataSetclass ImageTransform(): def __init__(self, resize, mean, std): self.data_transform = { 'train': transforms.Compose([ transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean, std) ]), 'val': transforms.Compose([ transforms.Resize(resize), transforms.CenterCrop(resize), transforms.ToTensor(), transforms.Normalize(mean, std) ]) } def __call__(self, img, phase='train'): return self.data_transform[phase](img)2.3、查看图像预处理前后的对比# 1.读取图像image_file_path = './data/goldenretriever-3724972_1280.jpg'img = Image.open(image_file_path)# 2.显示原图# img.show()plt.imshow(img)plt.show()# 3.预处理size = 224mean = (0.485, 0.456, 0.406)std = (0.229, 0.224, 0.225)transform = ImageTransform(size, mean, std)img_transformed = transform(img, phase='train')img_transformed = img_transformed.numpy().transpose((1, 2, 0))img_transformed = np.clip(img_transformed, 0, 1)plt.imshow(img_transformed)plt.show()

2.4、创建用于保存图片文件路径的列表变量# 用于保存蚂蚁和蜜蜂的图片文件路径列表变量def make_data_path_list(phase='train'): root_path = './data/hymenoptera_data/' target_path = osp.join(root_path + phase + '/**/*.jpg') # print(target_path) # ./data/hymenoptera_data/train/**/*.jpg # ./data/hymenoptera_data/val/**/*.jpg path_list = [] for path in glob.glob(target_path): path_list.append(path) return path_listtrain_list = make_data_path_list(phase='train')val_list = make_data_path_list(phase='val')print(train_list[:5])使用Pytorch实现深度学习的主要流程(pytorch技巧)

[‘./data/hymenoptera_data/train/bees/2638074627_6b3ae746a0.jpg’, ‘./data/hymenoptera_data/train/bees/507288830_f46e8d4cb2.jpg’, ‘./data/hymenoptera_data/train/bees/2405441001_b06c36fa72.jpg’, ‘./data/hymenoptera_data/train/bees/2962405283_22718d9617.jpg’, ‘./data/hymenoptera_data/train/bees/446296270_d9e8b93ecf.jpg’]

2.5、创建图片组成的Dataset# 构建datasetclass HymenopteraDataset(data.Dataset): def __init__(self, file_list, transform=None, phase='train'): self.file_list = file_list self.transform = transform self.phase = phase def __len__(self): return len(self.file_list) def __getitem__(self, index): img_path = self.file_list[index] img = Image.open(img_path) img_transformed = self.transform(img, self.phase) if self.phase == 'train': label = img_path[30:34] elif self.phase == 'val': label = img_path[28: 32] if label == 'ants': label = 0 elif label == 'bees': label = 1 return img_transformed, labeltrain_dataset = HymenopteraDataset( file_list=train_list, transform=ImageTransform(size, mean, std), phase='train')val_dataset = HymenopteraDataset( file_list=val_list, transform=ImageTransform(size, mean, std), phase='val')# 查看图片和标签index = 0print(train_dataset.__getitem__(index)[0].size())print(train_dataset.__getitem__(index)[1])

torch.Size([3, 224, 224]) 1

2.6、创建Dataloader# 创建DataLoaderbatch_size = 32train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, shuffle=True)val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=batch_size, shuffle=False)dataloaders_dict = {'train': train_loader, 'val': val_loader}batch_iterator = iter(dataloaders_dict['train'])inputs, labels = next(batch_iterator)print(inputs.size())print(labels)

torch.Size([32, 3, 224, 224]) tensor([0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1])

2.7、创建网络模型# 创建网络模型use_pretrained = Truenet = models.vgg16(pretrained=use_pretrained)print(net)net.classifier[6] = nn.Linear(in_features=4096, out_features=2)print(net)net.train()print('网络设置完毕: 载入已经学习完毕的权重,并设置为训练模式')

VGG( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace=True) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace=True) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace=True) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace=True) (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (18): ReLU(inplace=True) (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace=True) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace=True) (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (25): ReLU(inplace=True) (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (27): ReLU(inplace=True) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace=True) (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) (classifier): Sequential( (0): Linear(in_features=25088, out_features=4096, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU(inplace=True) (5): Dropout(p=0.5, inplace=False) (6): Linear(in_features=4096, out_features=1000, bias=True) ) ) VGG( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace=True) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace=True) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace=True) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace=True) (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (18): ReLU(inplace=True) (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace=True) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace=True) (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (25): ReLU(inplace=True) (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (27): ReLU(inplace=True) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace=True) (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) (classifier): Sequential( (0): Linear(in_features=25088, out_features=4096, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU(inplace=True) (5): Dropout(p=0.5, inplace=False) (6): Linear(in_features=4096, out_features=2, bias=True) ) ) 网络设置完毕: 载入已经学习完毕的权重,并设置为训练模式

2.8、定义损失函数# 定义损失函数criterion = nn.CrossEntropyLoss()2.9、设定最优化算法# 设定最优化算法params_to_update = []# print(net.named_parameters())update_param_names = ['classifier.6.weight', 'classifier.6.bias']for name, param in net.named_parameters(): # print(name) # print(param) if name in update_param_names: param.requires_grad = True params_to_update.append(param) print(name) else: param.requires_grad = Falseprint('=========================')print(params_to_update)# 设置最优化算法optimizer = optim.SGD(params=params_to_update, lr=0.001, momentum=0.9)print(optimizer)classifier.6.weight classifier.6.bias

[Parameter containing: tensor([[-0.0048, 0.0072, -0.0081, …, 0.0003, -0.0040, 0.0048], [ 0.0051, 0.0072, -0.0154, …, 0.0054, 0.0152, 0.0083]], requires_grad=True), Parameter containing: tensor([ 0.0108, -0.0054], requires_grad=True)] SGD ( Parameter Group 0 dampening: 0 lr: 0.001 momentum: 0.9 nesterov: False weight_decay: 0 )

2.10、训练和验证# 学习和验证的实行def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs): for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch + 1, num_epochs)) print('--------------------') for phase in ['train', 'val']: if phase == 'train': net.train() else: net.eval() epoch_loss = 0.0 epoch_corrects = 0 if epoch == 0 and phase == 'train': continue for inputs, labels in tqdm(dataloaders_dict[phase]): optimizer.zero_grad() with torch.set_grad_enabled(phase == 'train'): outputs = net(inputs) loss = criterion(outputs, labels) _, preds = torch.max(outputs, 1) if phase == 'train': loss.backward() optimizer.step() epoch_loss += loss.item() * inputs.size(0) epoch_corrects += torch.sum(preds == labels.data) epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset) epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset) print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))2.11、主函数if __name__ == '__main__': num_epochs = 2 train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)

2.12、保存和加载网络模型 # 保存和读取训练完毕的网络 # save_path = './weights_fine_tuning.pth' # torch.save(net.state_dict(), save_path) # 加载训练好的网络参数 load_path = './weights_fine_tuning.pth' load_weights = torch.load(load_path) net.load_state_dict(load_weights) print(net)
本文链接地址:https://www.jiuchutong.com/zhishi/300555.html 转载请保留说明!

上一篇:使用el-upload组件实现递归多文件上传(elementui的upload组件详解)

下一篇:vue3 响应式 API 之 ref(vue3 响应式ui框架)

  • 所得税汇算清缴招待费扣除标准
  • 开外经证需要预缴税几个点
  • 一个人可以做多少家公司法人
  • 年金是否一定是每年发生一次
  • 外经证错了已经交了税怎么办
  • 先报税还是先清卡反写
  • 未开票收入下月开票怎么报税
  • 长期股权投资审计说明
  • 成本类科目有哪些口诀
  • 工业企业成本会计核算的对象是什么
  • 土地使用税什么意思
  • 增值税普通发票和普通发票的区别怎么交税
  • 销售返利增值税按哪个税率
  • 私车公用可以企业所得税税前扣除吗
  • 没有ca证书怎么连接wifi加密设备
  • 资产损失申报扣除
  • 设备维修增值税
  • 工人受伤医药费计入什么科目
  • 某企业原材料采用实际成本核算,2019年6月
  • 企业停产或停业期间的费用包括
  • 职工食堂的费用怎么入账
  • 交叉持股的合并财务报表
  • 无形资产摊销方法应当反映其经济利益
  • 员工离职再入职要重新签订合同吗
  • 所得税费用什么时候结转
  • 应交税费的会计处理2018
  • 小规模减免的增值税怎么记账
  • backupnotify.exe是什么文件的进程 backupnotify进程安全吗
  • 如何给宽带加速使用
  • thinkphp 模型
  • linux杀死服务
  • 以前年度应付账款做到制造费用如何改账
  • icons是什么文件夹
  • 前端页面出现乱码
  • 最强笔记本2021
  • .hpp是什么文件
  • 待摊费用每月怎么摊
  • php read
  • PHP:imagefilledellipse()的用法_GD库图像处理函数
  • uml中的顺序图由什么组成
  • php编写用户注册界面
  • php中的pdo
  • 发票抵扣联能报销吗
  • vm网络不可达
  • 一般纳税人注销税务流程
  • 如何登记现金明细账
  • 进项税抵增值税
  • python里面init
  • 结余资金包括结转资金吗?
  • 批发和零售业行业代码
  • 技术转让和技术开发区别
  • 科目余额表和资产负债表的期末余额不一样,怎么办
  • 修改数据库为多个数据
  • 企业对公帐户怎么转出私人帐户
  • 收到的技术服务费计入什么科目
  • 营业税改增值税有什么好处
  • 制造费用科目一定无余额
  • 化妆品的成本利润率
  • 政府代建工程
  • 为什么一般纳税人可以选择简易计税
  • 暂估入库一直没有发票
  • 用SQL脚本读取Excel中的sheet数量及名称的方法代码
  • sql如何实现
  • solaris安装教程
  • 怎么进入win7系统
  • ubuntu 16:9
  • linux的run目录放什么文件
  • pe硬盘安装win7系统教程
  • GLWallpaperService分析一
  • html5画布五角星
  • opengl自学
  • unity cpu优化
  • javascript如何
  • javascript高级程序设计第五版 pdf下载
  • javascript从入门到放弃
  • Jquery中巧用Ajax的beforeSend方法
  • 成都市个人房屋出租税费
  • 单位购买房产作废怎么办
  • 我的宁夏灵活就业缴费失败
  • 定额发票属于什么类型
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设