位置: IT常识 - 正文

模型训练步骤(模型训练的过程是什么过程)

编辑:rootadmin
模型训练步骤

推荐整理分享模型训练步骤(模型训练的过程是什么过程),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:模型训练步骤包括,模型训练的步骤,模型训练的步骤,模型训练的步骤,模型训练步骤包括,模型训练步骤有哪些,模型训练步骤包括,模型训练步骤包括验证,内容如对您有帮助,希望把文章链接给更多的朋友!

1.在model.py搭建神经网络。

# 搭建神经网络 10分类网络。import torchfrom torch import nnclass net(nn.Module): def __init__(self): super(net, self).__init__() self.model = nn.Sequential( # 卷积 nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2), # 最大池化 nn.MaxPool2d(kernel_size=2), # 卷积 nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2), # 最大池化 nn.MaxPool2d(kernel_size=2), # 卷积 nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2), # 最大池化 nn.MaxPool2d(kernel_size=2), # 展平 nn.Flatten(), # 线性层 nn.Linear(in_features=64 * 4 * 4, out_features=64), nn.Linear(in_features=64, out_features=10) ) def forward(self, x): return self.model(x)

2.验证搭建网络的正确性

if __name__ == '__main__': # 测试网络的验证正确性 tudui = Tudui() input = torch.ones((64,3,32,32)) # batch_size=64(代表64张图片),3通道,32x32 output = tudui(input) print(output.shape)

结果是

torch.Size([64,10])

返回64行数据,每一行10个数据,代表每一张图片的概率。

3.在train.py下

①准备数据集,一个训练数据集,一个测试数据集。因为CIFAR10数据集是PIL,要转为tensor数据类型。

train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=torchvision.transforms.ToTensor(), download=True)test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)

②加载数据集。利用DataLoader加载数据集。

train_dataloader = DataLoader(dataset=train_data, batch_size=64)test_dataloader = DataLoader(dataset=test_data, batch_size=64)

③创建网络模型

from model import * wang = net()

模型训练步骤(模型训练的过程是什么过程)

④创建损失函数

loss_fn = nn.CrossEntropyLoss()

⑤创建优化器

learning_rate = 0.01optimizer = torch.optim.SGD(params=wang.parameters(), lr=learning_rate)

⑥设置网络训练参数

# 设置训练网络的一些参数# 记录训练次数total_train_step = 0# 记录测试的次数total_test_step = 0# 训练的轮数epoch = 10

⑦开始训练

for i in range(epoch): print("----------第{}轮训练开始-----------".format(i+1)) # i从0-9 # 训练步骤开始 for data in train_dataloader: imgs,targets = data outputs = tudui(imgs) loss = loss_fn(outputs,targets) # 优化器优化模型 optimizer.zero_grad() # 首先要梯度清零 loss.backward() # 反向传播得到每一个参数节点的梯度 optimizer.step() # 对参数进行优化 total_train_step += 1 print("训练次数:{},loss:{}".format(total_train_step,loss.item()))

【补充:】

import torcha = torch.tensor(5)print(a)print(a.item())

输出:

tensor(5)

5.【测试】:看模型是否训练好。

每次训练完进行一轮测试,看测试集的损失或者正确率评估模型是否训练好。

测试过程模型不需要调优,利用现有的模型测试。

with torch.no_grad():

6.在上述代码继续编写

# 测试步骤开始 total_test_loss = 0 with torch.no_grad(): # 无梯度,不进行调优 for data in test_dataloader: imgs,targets = data outputs = tudui(imgs) loss = loss_fn(outputs,targets) # 该loss为部分数据在网络模型上的损失,为tensor数据类型 # 求整体测试数据集上的误差或正确率 total_test_loss = total_test_loss + loss.item() # loss为tensor数据类型,而total_test_loss为普通数字 print("整体测试集上的Loss:{}".format(total_test_loss))

7.跟TensorbBoard相结合

import torchvision.datasetsfrom torch.utils.tensorboard import SummaryWriterfrom model import *from torch import nnfrom torch.utils.data import DataLoader# 准备数据集,CIFAR10 数据集是PIL Image,要转换为tensor数据类型train_data = torchvision.datasets.CIFAR10(root="../data",train=True,transform=torchvision.transforms.ToTensor(),download=True)test_data = torchvision.datasets.CIFAR10(root="../data",train=False,transform=torchvision.transforms.ToTensor(),download=True)# 看一下训练数据集和测试数据集都有多少张(如何获得数据集的长度)train_data_size = len(train_data) # length 长度test_data_size = len(test_data)# 如果train_data_size=10,那么打印出的字符串为:训练数据集的长度为:10print("训练数据集的长度为:{}".format(train_data_size)) # 字符串格式化,把format中的变量替换{}print("测试数据集的长度为:{}".format(test_data_size))# 利用 DataLoader 来加载数据集train_dataloader = DataLoader(train_data,batch_size=64)test_dataloader = DataLoader(test_data,batch_size=64)# 创建网络模型tudui = Tudui()# 创建损失函数loss_fn = nn.CrossEntropyLoss() # 分类问题可以用交叉熵# 定义优化器learning_rate = 0.01 # 另一写法:1e-2,即1x 10^(-2)=0.01optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate) # SGD 随机梯度下降# 设置训练网络的一些参数total_train_step = 0 # 记录训练次数total_test_step = 0 # 记录测试次数epoch = 10 # 训练轮数# 添加tensorboardwriter = SummaryWriter("../logs_train")for i in range(epoch): print("----------第{}轮训练开始-----------".format(i+1)) # i从0-9 # 训练步骤开始 for data in train_dataloader: imgs,targets = data outputs = tudui(imgs) loss = loss_fn(outputs,targets) # 优化器优化模型 optimizer.zero_grad() # 首先要梯度清零 loss.backward() # 反向传播得到每一个参数节点的梯度 optimizer.step() # 对参数进行优化 total_train_step += 1 if total_train_step % 100 ==0: # 逢百才打印记录 print("训练次数:{},loss:{}".format(total_train_step,loss.item())) writer.add_scalar("train_loss",loss.item(),total_train_step) # 测试步骤开始 total_test_loss = 0 with torch.no_grad(): # 无梯度,不进行调优 for data in test_dataloader: imgs,targets = data outputs = tudui(imgs) loss = loss_fn(outputs,targets) # 该loss为部分数据在网络模型上的损失,为tensor数据类型 # 求整体测试数据集上的误差或正确率 total_test_loss = total_test_loss + loss.item() # loss为tensor数据类型,而total_test_loss为普通数字 print("整体测试集上的Loss:{}".format(total_test_loss)) writer.add_scalar("test_loss",total_test_loss,total_test_step) total_test_step += 1writer.close()

保存模型:

torch.save(tudui,"tudui_{}.pth".format(i)) # 每一轮保存一个结果 print("模型已保存")writer.close()

【代码优化,提升正确率】

# 求整体测试数据集上的误差或正确率 accuracy = (outputs.argmax(1) == targets).sum() # 1:横向比较,==:True或False,sum:计算True或False个数 total_accuracy = total_accuracy + accuracy print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size)) # 正确率为预测对的个数除以测试集长度 writer.add_scalar("test_accuracy",total_test_loss,total_test_step,total_test_step)

【完整代码】

import torchimport torchvision.datasetsfrom torch.utils.tensorboard import SummaryWriterfrom model import *from torch import nnfrom torch.utils.data import DataLoader# 准备数据集,CIFAR10 数据集是PIL Image,要转换为tensor数据类型train_data = torchvision.datasets.CIFAR10(root="../data",train=True,transform=torchvision.transforms.ToTensor(),download=True)test_data = torchvision.datasets.CIFAR10(root="../data",train=False,transform=torchvision.transforms.ToTensor(),download=True)# 看一下训练数据集和测试数据集都有多少张(如何获得数据集的长度)train_data_size = len(train_data) # length 长度test_data_size = len(test_data)# 如果train_data_size=10,那么打印出的字符串为:训练数据集的长度为:10print("训练数据集的长度为:{}".format(train_data_size)) # 字符串格式化,把format中的变量替换{}print("测试数据集的长度为:{}".format(test_data_size))# 利用 DataLoader 来加载数据集train_dataloader = DataLoader(train_data,batch_size=64)test_dataloader = DataLoader(test_data,batch_size=64)# 创建网络模型tudui = Tudui()# 创建损失函数loss_fn = nn.CrossEntropyLoss() # 分类问题可以用交叉熵# 定义优化器learning_rate = 0.01 # 另一写法:1e-2,即1x 10^(-2)=0.01optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate) # SGD 随机梯度下降# 设置训练网络的一些参数total_train_step = 0 # 记录训练次数total_test_step = 0 # 记录测试次数epoch = 10 # 训练轮数# 添加tensorboardwriter = SummaryWriter("../logs_train")for i in range(epoch): print("----------第{}轮训练开始-----------".format(i+1)) # i从0-9 # 训练步骤开始 for data in train_dataloader: imgs,targets = data outputs = tudui(imgs) loss = loss_fn(outputs,targets) # 优化器优化模型 optimizer.zero_grad() # 首先要梯度清零 loss.backward() # 反向传播得到每一个参数节点的梯度 optimizer.step() # 对参数进行优化 total_train_step += 1 if total_train_step % 100 ==0: # 逢百才打印记录 print("训练次数:{},loss:{}".format(total_train_step,loss.item())) writer.add_scalar("train_loss",loss.item(),total_train_step) # 测试步骤开始 total_test_loss = 0 total_accuracy = 0 with torch.no_grad(): # 无梯度,不进行调优 for data in test_dataloader: imgs,targets = data outputs = tudui(imgs) loss = loss_fn(outputs,targets) # 该loss为部分数据在网络模型上的损失,为tensor数据类型 # 求整体测试数据集上的误差或正确率 total_test_loss = total_test_loss + loss.item() # loss为tensor数据类型,而total_test_loss为普通数字 accuracy = (outputs.argmax(1) == targets).sum() # 1:横向比较,==:True或False,sum:计算True或False个数 total_accuracy = total_accuracy + accuracy print("整体测试集上的Loss:{}".format(total_test_loss)) print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size)) # 正确率为预测对的个数除以测试集长度 writer.add_scalar("test_loss",total_test_loss,total_test_step) writer.add_scalar("test_accuracy",total_test_loss,total_test_step,total_test_step) total_test_step += 1 torch.save(tudui,"tudui_{}.pth".format(i)) # 每一轮保存一个结果 print("模型已保存")writer.close()
本文链接地址:https://www.jiuchutong.com/zhishi/298395.html 转载请保留说明!

上一篇:8种css居中实现的详细实现方式了(css各种居中)

下一篇:JavaWeb 项目 --- 表白墙 和 在线相册(javaweb项目开发流程)

  • 微信收付款二维码怎么设置密码(微信收付款二维码截图可以用吗)

    微信收付款二维码怎么设置密码(微信收付款二维码截图可以用吗)

  • 苹果手机界面变黑色改成白色怎么办(苹果手机界面变成搜索怎么办)

    苹果手机界面变黑色改成白色怎么办(苹果手机界面变成搜索怎么办)

  • 华为平板c3什么时候出的(华为平板哪款最好)

    华为平板c3什么时候出的(华为平板哪款最好)

  • 机械硬盘装在机箱哪里(机械硬盘装在机箱位置图)

    机械硬盘装在机箱哪里(机械硬盘装在机箱位置图)

  • qq小秘密为什么会被删除(qq小秘密为什么没有了)

    qq小秘密为什么会被删除(qq小秘密为什么没有了)

  • 华为nova7pro有指纹解锁吗(华为nova7pro有指示灯吗)

    华为nova7pro有指纹解锁吗(华为nova7pro有指示灯吗)

  • 微信运动几点发消息(微信运动几点发布)

    微信运动几点发消息(微信运动几点发布)

  • 无雾加湿器和有雾加湿器区别(无雾加湿器和有雾加湿器哪个寿命长)

    无雾加湿器和有雾加湿器区别(无雾加湿器和有雾加湿器哪个寿命长)

  • 小米cc9e支持无线快充吗(小米cc9e支持5gwifi)

    小米cc9e支持无线快充吗(小米cc9e支持5gwifi)

  • 支付宝获取手机定位服务的权限在哪里(支付宝获取手机定位服务的权限在哪里设置)

    支付宝获取手机定位服务的权限在哪里(支付宝获取手机定位服务的权限在哪里设置)

  • plus版和普通版的区别(plus版是完整版吗)

    plus版和普通版的区别(plus版是完整版吗)

  • 钉钉可以匿名发消息吗(钉钉匿名发消息)

    钉钉可以匿名发消息吗(钉钉匿名发消息)

  • 爱奇艺怎么设置多人使用(爱奇艺怎么设置只看他)

    爱奇艺怎么设置多人使用(爱奇艺怎么设置只看他)

  • 红米k20pro取消掉上划搜索(红米k20prohd怎么关闭)

    红米k20pro取消掉上划搜索(红米k20prohd怎么关闭)

  • iphone11是双卡吗(苹果11第二张卡装哪里)

    iphone11是双卡吗(苹果11第二张卡装哪里)

  • 酷我音乐如何k歌(酷我音乐如何看听歌时间)

    酷我音乐如何k歌(酷我音乐如何看听歌时间)

  • iphone8p运行内存多大(iphone8p运行内存无故占满)

    iphone8p运行内存多大(iphone8p运行内存无故占满)

  • 抖音里面说话配音在哪(抖音里面说话配音怎么弄)

    抖音里面说话配音在哪(抖音里面说话配音怎么弄)

  • 小度音箱如何控制灯(小度音箱如何控制空调)

    小度音箱如何控制灯(小度音箱如何控制空调)

  • 华为屏幕一会亮一会暗(华为手机怎么屏幕一会亮一会不亮)

    华为屏幕一会亮一会暗(华为手机怎么屏幕一会亮一会不亮)

  • 怎么解决投屏延迟(如何解决投屏延迟)

    怎么解决投屏延迟(如何解决投屏延迟)

  • 华为方舟编译器怎么用(华为方舟编译器概念股)

    华为方舟编译器怎么用(华为方舟编译器概念股)

  • 手机屏幕颜色怎么调(手机屏幕颜色怎么调回正常)

    手机屏幕颜色怎么调(手机屏幕颜色怎么调回正常)

  • 文件类型怎么选择所有文件详细教程(文件类型设置)

    文件类型怎么选择所有文件详细教程(文件类型设置)

  • 如何自己搭建一个ai画图系统? 从0开始云服务器部署novelai(如何自己搭建一个邮箱服务器)

    如何自己搭建一个ai画图系统? 从0开始云服务器部署novelai(如何自己搭建一个邮箱服务器)

  • VUE项目运行失败原因以及解决办法(以vscode为例)(vue项目运行报错)

    VUE项目运行失败原因以及解决办法(以vscode为例)(vue项目运行报错)

  • 委托加工业务中,委托方是纳税义务人
  • 外商投资合伙企业的性质与特征
  • 持有至到期投资核算内容
  • 怎么算毛利润计算公式
  • 砖厂开票员的工作流程
  • 小企业会计准则没有以前年度损益调整科目
  • 小规模公司初期注销流程
  • 会计账簿账目核对要求包括哪些
  • 支付无法取得发票的赔偿金可否税前扣除
  • 企业所得税广告费结转先扣哪一年
  • 国有资产划转如何做账
  • 公司的样品一般怎么处理
  • 采购差价构成犯罪吗
  • 年终奖不走工资走存单,需要缴税吗?
  • 预缴增值税预缴的城建税怎么申报
  • 无法执行合同的说明函
  • 提成工资可以扣发吗?
  • 税控盘维护费抵减分录
  • 个人房补申请书怎么办
  • 工程已完工又发生了成本怎么处理
  • 城建税教育附加税的会计分录
  • 定期定额自行申报表计税依据
  • 零申报企业所得税
  • 物业公司安装监控为了什么
  • 关于递延所得税资产负债的表述
  • 出纳现金日记账怎么记账
  • 工业投资额是指什么
  • 出口退税挂靠业务如何做帐?
  • win10禁止使用网络
  • 无偿赠送的原材料怎么处理
  • 最新w10系统专业版
  • u盘的重装系统
  • 收到加盟费怎么入账
  • 马齿笕对什么病最有效?
  • phpstan
  • 无票收入什么时候确认收入
  • 发票开具有误拒收后销售方如何处理?
  • 给工程项目买保险是选哪个保险公司
  • php输出隔行变色的表格
  • 固定资产改造费用化账务处理
  • 个税网上申报流程视频
  • 小微企业所得税如何填报
  • 业务招待费可以开专票抵扣吗
  • 单位购日用品计提折旧吗
  • 什么是完税证明?完税证明丢了怎么办公司
  • 机动车发票哪几联 做帐
  • 税收罚款支出计入其他应付款吗对吗
  • 待处理财产损益期末结转到哪里
  • php自动压缩图片
  • SqlServer2012中First_Value函数简单分析
  • 小规模纳税人减半征收的六税两费
  • 日常费用报销表格
  • 单位缴纳工会经费有什么用
  • 用信用卡消费扣谁的手续费
  • 外贸企业的汇率怎么算
  • 红字发票开错了已上传如何作废?
  • 高新企业 要求
  • 企业的固定资产由于技术进步等原因
  • 安全生产费计提和使用
  • 长期无法收回的应收账款如何处理
  • win8系统怎么关闭投影
  • windows91
  • 组装机没有装系统开机会怎么样
  • 三星笔记是干什么用的
  • win7系统怎么用
  • win7电脑开机声音怎么改
  • securecrt教程
  • win7用户账户控制设置电脑重启后恢复
  • mac 活动监视器在哪里
  • mom.exe是什么程序
  • win10搜索不到无线网卡
  • win7系统运行怎么打开
  • win10系统开机后任务栏无响应怎么解决
  • css如何实现
  • unity局域网多人游戏
  • node.js的安装步骤
  • Python 正则表达式入门(中级篇)
  • 小规模纳税人开1%普票怎么报税
  • 全国退休人员有几多人
  • 90平房子税
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设