位置: IT常识 - 正文

使用Pytorch实现深度学习的主要流程(pytorch技巧)

编辑:rootadmin
使用Pytorch实现深度学习的主要流程 一、使用Pytorch实现深度学习的主要流程

推荐整理分享使用Pytorch实现深度学习的主要流程(pytorch技巧),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch怎么用,pytorch例程,pytorch 简单例子,pytorch教程,pytorch基本操作,pytorch例程,pytorch怎么用,pytorch怎么用,内容如对您有帮助,希望把文章链接给更多的朋友!

使用Pytorch进行深度学习的实现流程主要包含如下几个部分: 1、预处理、后处理并确认网络的输入和输出 2、创建Dataset 3、创建DataLoader 4、创建网络模型 5、定义正向传播函数(forward) 6、定义损失函数 7、设置最优化算法 8、进行训练和验证 9、达到标准保存模型 10、加载模型使用测试数据进行测试 在使用Pytorch实现深度学习的整体流程中,首先需要对准备实现的深度学习算法从整体上进行把握。即对预处理以及后处理,网络模型的输入和输出进行确认。 创建Dataset就是将输入数据和与其对应的标签组成配对数据进行保存的类。这里将用于处理数据的预处理类的实例指定到Dataset中,并设定其在从文件中读取对象数据时,自动对输入数据进行预处理。 接下来是创建Dataloader,DataLoader是用来设定从Dataset中读取数据的具体方法的类。在深度学习中,通常都是采用小批次学习的方式,将多个数据同时从Dataset中取出,并传递给神经网络进行学习训练,DataLoader就是负责简化从Dataset中取出小批次数据这一操作的类,需要分别创建好用于训练数据以及验证数据的Dataloader。 接下来是创建网络模型,创建网络模型共有三种方式,第一种从零开始实现整个网络模型;第二种是直接载入已经训练好的网络模型,第三种是以现有训练好的网络模型为基础,将其改造为自己需要的模型。在深度学习实际应用中,大多数情况是以训练好的网络模型为基础,将其改造成符合自身需要的模型。 在成功创建网络模型之后,就需要定义网络模型的正向传播函数,forward函数,接下来要做的就是定义用于将误差值进行反向传播的损失函数,对于解决不同的任务会设置不同的损失函数。 下一步就是设定在对网络模型的连接参数进行训练时使用的优化算法,通过误差的反向传播,可以对连接参数的误差对应的梯度进行计算。优化算法就是指如何根据这一梯度值计算出连接参数的修正量具体算法。常用的优化算法有 Adam、SGD。 通过上述步骤,就完成了进行深度学习所需要的所有设置,接下来进行实际的学习和验证操作。通常以epoch为单位,对训练数据的性能和验证数据的性能进行比较,如果验证数据的性能停止提升,之后的训练都会陷入到过拟合状态,因此需要及时停止训练,提前终止网络学习的方法又称为early stopping。 学习完成之后保存我们训练得到的最优模型,之后加载模型进行推理。

二、代码实战2.1、软件包的导入以及初始设置# 实现代码的初始设置import globimport os.path as ospimport randomimport numpy as npimport jsonfrom PIL import Imagefrom tqdm import tqdmimport matplotlib.pyplot as pltimport torchimport torch.nn as nnimport torch.optim as optimimport torch.utils.data as dataimport torchvisionfrom torchvision import models, transformstorch.manual_seed(1234)np.random.seed(1234)random.seed(1234)2.2、创建Dataset# 创建DataSetclass ImageTransform(): def __init__(self, resize, mean, std): self.data_transform = { 'train': transforms.Compose([ transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean, std) ]), 'val': transforms.Compose([ transforms.Resize(resize), transforms.CenterCrop(resize), transforms.ToTensor(), transforms.Normalize(mean, std) ]) } def __call__(self, img, phase='train'): return self.data_transform[phase](img)2.3、查看图像预处理前后的对比# 1.读取图像image_file_path = './data/goldenretriever-3724972_1280.jpg'img = Image.open(image_file_path)# 2.显示原图# img.show()plt.imshow(img)plt.show()# 3.预处理size = 224mean = (0.485, 0.456, 0.406)std = (0.229, 0.224, 0.225)transform = ImageTransform(size, mean, std)img_transformed = transform(img, phase='train')img_transformed = img_transformed.numpy().transpose((1, 2, 0))img_transformed = np.clip(img_transformed, 0, 1)plt.imshow(img_transformed)plt.show()

2.4、创建用于保存图片文件路径的列表变量# 用于保存蚂蚁和蜜蜂的图片文件路径列表变量def make_data_path_list(phase='train'): root_path = './data/hymenoptera_data/' target_path = osp.join(root_path + phase + '/**/*.jpg') # print(target_path) # ./data/hymenoptera_data/train/**/*.jpg # ./data/hymenoptera_data/val/**/*.jpg path_list = [] for path in glob.glob(target_path): path_list.append(path) return path_listtrain_list = make_data_path_list(phase='train')val_list = make_data_path_list(phase='val')print(train_list[:5])使用Pytorch实现深度学习的主要流程(pytorch技巧)

[‘./data/hymenoptera_data/train/bees/2638074627_6b3ae746a0.jpg’, ‘./data/hymenoptera_data/train/bees/507288830_f46e8d4cb2.jpg’, ‘./data/hymenoptera_data/train/bees/2405441001_b06c36fa72.jpg’, ‘./data/hymenoptera_data/train/bees/2962405283_22718d9617.jpg’, ‘./data/hymenoptera_data/train/bees/446296270_d9e8b93ecf.jpg’]

2.5、创建图片组成的Dataset# 构建datasetclass HymenopteraDataset(data.Dataset): def __init__(self, file_list, transform=None, phase='train'): self.file_list = file_list self.transform = transform self.phase = phase def __len__(self): return len(self.file_list) def __getitem__(self, index): img_path = self.file_list[index] img = Image.open(img_path) img_transformed = self.transform(img, self.phase) if self.phase == 'train': label = img_path[30:34] elif self.phase == 'val': label = img_path[28: 32] if label == 'ants': label = 0 elif label == 'bees': label = 1 return img_transformed, labeltrain_dataset = HymenopteraDataset( file_list=train_list, transform=ImageTransform(size, mean, std), phase='train')val_dataset = HymenopteraDataset( file_list=val_list, transform=ImageTransform(size, mean, std), phase='val')# 查看图片和标签index = 0print(train_dataset.__getitem__(index)[0].size())print(train_dataset.__getitem__(index)[1])

torch.Size([3, 224, 224]) 1

2.6、创建Dataloader# 创建DataLoaderbatch_size = 32train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, shuffle=True)val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=batch_size, shuffle=False)dataloaders_dict = {'train': train_loader, 'val': val_loader}batch_iterator = iter(dataloaders_dict['train'])inputs, labels = next(batch_iterator)print(inputs.size())print(labels)

torch.Size([32, 3, 224, 224]) tensor([0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1])

2.7、创建网络模型# 创建网络模型use_pretrained = Truenet = models.vgg16(pretrained=use_pretrained)print(net)net.classifier[6] = nn.Linear(in_features=4096, out_features=2)print(net)net.train()print('网络设置完毕: 载入已经学习完毕的权重,并设置为训练模式')

VGG( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace=True) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace=True) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace=True) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace=True) (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (18): ReLU(inplace=True) (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace=True) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace=True) (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (25): ReLU(inplace=True) (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (27): ReLU(inplace=True) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace=True) (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) (classifier): Sequential( (0): Linear(in_features=25088, out_features=4096, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU(inplace=True) (5): Dropout(p=0.5, inplace=False) (6): Linear(in_features=4096, out_features=1000, bias=True) ) ) VGG( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace=True) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace=True) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace=True) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace=True) (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (18): ReLU(inplace=True) (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace=True) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace=True) (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (25): ReLU(inplace=True) (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (27): ReLU(inplace=True) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace=True) (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) (classifier): Sequential( (0): Linear(in_features=25088, out_features=4096, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU(inplace=True) (5): Dropout(p=0.5, inplace=False) (6): Linear(in_features=4096, out_features=2, bias=True) ) ) 网络设置完毕: 载入已经学习完毕的权重,并设置为训练模式

2.8、定义损失函数# 定义损失函数criterion = nn.CrossEntropyLoss()2.9、设定最优化算法# 设定最优化算法params_to_update = []# print(net.named_parameters())update_param_names = ['classifier.6.weight', 'classifier.6.bias']for name, param in net.named_parameters(): # print(name) # print(param) if name in update_param_names: param.requires_grad = True params_to_update.append(param) print(name) else: param.requires_grad = Falseprint('=========================')print(params_to_update)# 设置最优化算法optimizer = optim.SGD(params=params_to_update, lr=0.001, momentum=0.9)print(optimizer)classifier.6.weight classifier.6.bias

[Parameter containing: tensor([[-0.0048, 0.0072, -0.0081, …, 0.0003, -0.0040, 0.0048], [ 0.0051, 0.0072, -0.0154, …, 0.0054, 0.0152, 0.0083]], requires_grad=True), Parameter containing: tensor([ 0.0108, -0.0054], requires_grad=True)] SGD ( Parameter Group 0 dampening: 0 lr: 0.001 momentum: 0.9 nesterov: False weight_decay: 0 )

2.10、训练和验证# 学习和验证的实行def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs): for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch + 1, num_epochs)) print('--------------------') for phase in ['train', 'val']: if phase == 'train': net.train() else: net.eval() epoch_loss = 0.0 epoch_corrects = 0 if epoch == 0 and phase == 'train': continue for inputs, labels in tqdm(dataloaders_dict[phase]): optimizer.zero_grad() with torch.set_grad_enabled(phase == 'train'): outputs = net(inputs) loss = criterion(outputs, labels) _, preds = torch.max(outputs, 1) if phase == 'train': loss.backward() optimizer.step() epoch_loss += loss.item() * inputs.size(0) epoch_corrects += torch.sum(preds == labels.data) epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset) epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset) print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))2.11、主函数if __name__ == '__main__': num_epochs = 2 train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)

2.12、保存和加载网络模型 # 保存和读取训练完毕的网络 # save_path = './weights_fine_tuning.pth' # torch.save(net.state_dict(), save_path) # 加载训练好的网络参数 load_path = './weights_fine_tuning.pth' load_weights = torch.load(load_path) net.load_state_dict(load_weights) print(net)
本文链接地址:https://www.jiuchutong.com/zhishi/300555.html 转载请保留说明!

上一篇:使用el-upload组件实现递归多文件上传(elementui的upload组件详解)

下一篇:vue3 响应式 API 之 ref(vue3 响应式ui框架)

  • 怎样创新网络推广方式(网络推广新方法)

    怎样创新网络推广方式(网络推广新方法)

  • 小米手机如何看拦截短信和来电(小米手机如何看使用了多长时间)

    小米手机如何看拦截短信和来电(小米手机如何看使用了多长时间)

  • 高德地图电子狗功能在哪里(高德地图电子狗没声音)

    高德地图电子狗功能在哪里(高德地图电子狗没声音)

  • 天翼云盘家庭云视频别人能看到吗(天翼云盘家庭云怎么删除成员)

    天翼云盘家庭云视频别人能看到吗(天翼云盘家庭云怎么删除成员)

  • 为什么电脑没声音(为什么电脑没声音 显示切换输出设备)

    为什么电脑没声音(为什么电脑没声音 显示切换输出设备)

  • 苹果生态是什么意思(苹果生态是啥)

    苹果生态是什么意思(苹果生态是啥)

  • 抖音账号权重有6个级别(抖音账号权重有多少)

    抖音账号权重有6个级别(抖音账号权重有多少)

  • 9x支持快充吗(9x支持快速充电吗)

    9x支持快充吗(9x支持快速充电吗)

  • vivox50微信视频带美颜吗(vivox50微信视频美颜在哪里)

    vivox50微信视频带美颜吗(vivox50微信视频美颜在哪里)

  • 钉钉浮窗看抖音算时间吗(钉钉悬浮窗看抖音会不会计入时长)

    钉钉浮窗看抖音算时间吗(钉钉悬浮窗看抖音会不会计入时长)

  • qq头像模糊怎么解决(qq头像模糊处理方法)

    qq头像模糊怎么解决(qq头像模糊处理方法)

  • 华为畅连视频怎么收费(华为畅连视频花钱吗)

    华为畅连视频怎么收费(华为畅连视频花钱吗)

  • iphone5是什么处理器(苹果5c是什么)

    iphone5是什么处理器(苹果5c是什么)

  • 360N6Pro可以双卡双待吗(360手机双卡)

    360N6Pro可以双卡双待吗(360手机双卡)

  • 苹果11屏幕发黄正常吗(苹果11屏幕发黄怎么解决)

    苹果11屏幕发黄正常吗(苹果11屏幕发黄怎么解决)

  • 华为平板总是黑屏怎么回事(华为平板总是黑屏怎么办)

    华为平板总是黑屏怎么回事(华为平板总是黑屏怎么办)

  • 为什么来电不显示姓名(为什么来电不显示名字)

    为什么来电不显示姓名(为什么来电不显示名字)

  • 字符间距怎么设置(字符间距怎么设置缩放)

    字符间距怎么设置(字符间距怎么设置缩放)

  • ipad如何取消订阅(ipad如何取消)

    ipad如何取消订阅(ipad如何取消)

  • 魅族怎么开多任务(魅族怎么开多任务运存占用)

    魅族怎么开多任务(魅族怎么开多任务运存占用)

  • 苹果x是2k分辨率吗(苹果x是2k分辨率多少)

    苹果x是2k分辨率吗(苹果x是2k分辨率多少)

  • vjvj是什么牌子手机(jvjow是什么牌子)

    vjvj是什么牌子手机(jvjow是什么牌子)

  • xmax尺寸多大(iphonexsmax长多少厘米)

    xmax尺寸多大(iphonexsmax长多少厘米)

  • 苹果手机电池标志怎么变黄了(苹果手机电池标志变白色了)

    苹果手机电池标志怎么变黄了(苹果手机电池标志变白色了)

  • word打开很慢如何办(word打开慢是怎么回事)

    word打开很慢如何办(word打开慢是怎么回事)

  • Win11正式版值得升吗?Win11正式版和Win10区别对比介绍(windows11正式版好用吗)

    Win11正式版值得升吗?Win11正式版和Win10区别对比介绍(windows11正式版好用吗)

  • 取得的证券投资业绩
  • 社保与个税有关系么
  • 公司申请破产后股东需要还债吗
  • 联合体项目工程款如何拨付
  • 公司现金支票取钱需要带什么资料
  • 小规模纳税人涉税风险
  • 公转法人交税
  • 国有企业改制资产评估增值税收规
  • 专项用途财政资金纳税调整规则
  • 库存产品亏本销售账务处理
  • 预计销售退回的钱怎么算
  • 企业受托开发软件是什么
  • 非金融机构借款计入什么科目
  • 境外派遣员工境外所得税是什么时候申报?
  • 个体工商户税务登记证需要什么资料
  • 发票已认证对方起诉有效吗
  • 出租不动产什么时候交税
  • 红字增值税专用发票信息表怎么填
  • 企业微信收款的钱怎么提取出来
  • 联想thinkpad安装win7方法
  • 事业单位财政拨款是编制吗
  • 美团佣金收费标准结构图
  • 在当前目录下打开cmd
  • 如何在qq好友旁边打字
  • 如何关闭win11系统
  • 结转本年利润的分录怎么写
  • 委托外部加工材料支付加工费计入
  • 同一控制下的企业合并,合并方在企业合并中取得的资产
  • 应税货物销售额是什么意思
  • 按税收的计税依据为标准税收分为
  • AI:ModelScope(一站式开源的模型即服务共享平台)的简介、安装、使用方法之详细攻略
  • 发放职工薪酬账务怎么做
  • 农业经营许可证范围
  • 浅谈双减背景下的高效课堂
  • 先付款后收到发票怎么入账
  • 基于html的旅游网站设计源代码
  • erp面试题目100及最佳答案
  • 投资性房地产在资产负债表哪个科目
  • 员工办理健康证需要什么材料
  • 开票品名不一样有什么关系
  • sql有数据保护功能
  • Shading-JDBC、ShadingSphere、ShardingProxy 使用详解
  • 租车出差差旅费标准
  • 经营一家淘宝店铺,自然就应该做好
  • 资本公积和盈余公积都与利润有关
  • 长期待摊费用的摊销期限应该是
  • 资金账簿印花税减半政策
  • 职工薪酬包括哪几类
  • 一般纳税人交增值税的账务处理
  • 无形资产未确认融资费用例题
  • 借递延所得税资产贷递延所得税费用
  • 小规模公司购买汽车会计分录
  • 应付账款收不回发票该如何调整
  • 参加新冠疫情防控工作感悟 医务人员
  • 有支出没有发票应怎么整改
  • 卖二手车怎么做账务处理
  • 汽车行业的保险返点怎么算
  • 外币存款利息是不是外币
  • 天然气管道安装费多少钱一米
  • 出口收到货款怎么做账
  • 银行账户是不是卡号
  • 会务费发票开普票还是专票
  • 原材料入库单应根据采购订单还是到货数量
  • 预收账款怎么做账
  • virtualbox?
  • win8关机后自动重启怎么办
  • win xp怎么样
  • linux中文件系统
  • 文件夹底部显示
  • windows7的开机启动项在哪里
  • linux 命令连接
  • Error: String types not allowed (at 'layout_gravity' with value 'bottom/center_horizontal').
  • 微信小程序中显示app.json在项目根目录未找到怎么回事
  • vue-cal
  • mac安装nodejs的权限问题
  • javascriptcsdn
  • python的例子
  • 叉车需要手续吗
  • 地税局和税务局一样吗
  • 河北电子税务局社保缴费流程
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设