位置: IT常识 - 正文

SSD训练数据集流程(学习记录)(ssd训练自己的数据集pytorch)

编辑:rootadmin
SSD训练数据集流程(学习记录)

推荐整理分享SSD训练数据集流程(学习记录)(ssd训练自己的数据集pytorch),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:ssd 训练,ssd模型训练,ssdd数据集,svm训练数据集,ssd训练loss降不下来,ssd训练loss降不下来,ssd 训练,ssd 训练,内容如对您有帮助,希望把文章链接给更多的朋友!

关于理论部分我看的是b站“霹雳吧啦Wz”的SSD理论讲解,作为入门小白表示能听懂,需要的同学可以自行观看

目录

1.训练环境

2.训练步骤


1.训练环境

我的环境是win11+anaconda+python3.6.13+pytorch1.10.2+cuda11.6

2.训练步骤

(1)下载SSD源码

可到github进行下载

GitHub - amdegroot/ssd.pytorch: A PyTorch Implementation of Single Shot MultiBox Detector

(2)下载模型文件

VGG16_reducedfc.pth预训练模型下载地址:https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth

将下载的模型文件放置于ssd源码目录中  wights/vgg16_reducedfc.pth

(3)数据集准备

与大多数训练模型一样,ssd支持的训练格式为VOC和coco,这里采用voc2007作为演示,制作自己的数据集以及labimg的使用可自行观看yolo数据集标注软件安装+使用流程_道人兄的博客-CSDN博客_yolo数据集标注工具

voc2007的具体下载方式我也不多赘述,网络上百度也有,或者直接看我之前写的也有提到使用Faster—RCNN训练数据集流程(学习记录)_道人兄的博客-CSDN博客

将下载后的voc2007数据集放置于./data/VOCdevkit/中

 然后到ssd.pytorch-master/data/中的voc0712.py进行修改其中的VOC_ROOT = osp.join(HOME, "data/VOCdevkit/"),他这里的HOME老是读取我的C盘位置,所以一直报错,我直接把数据集的绝对路径写上去了就没报错

SSD训练数据集流程(学习记录)(ssd训练自己的数据集pytorch)

将 voc0712.py文件中VOCDetection类的__init__函数,将image_sets修改为[('2007', 'train'), ('2007', 'val'),('2007','test')],修改后的结果如下。

def __init__(self, root, image_sets=[('2007', 'train'), ('2007', 'val'),('2007','test')], transform=None, target_transform=VOCAnnotationTransform(), dataset_name='VOC0712'):

其中如果是训练自己的数据集,记得修改voc0712.py文件中的VOC_CLASSES 变量。例如,将VOC_CLASSES修改为person类,注意如果只有一类则需要加方括号,修改后的结果如下。

VOC_CLASSES = [('person')

如果训练自己的数据集,还需要修改config.py文件中的voc字典变量。将其中的num_classes修改为2(以person为例)(背景类+你训练集的种类个数),第一次调试时可以将max_iter调小至1000,修改后的结果如下。

voc = { 'num_classes': 2, 'lr_steps': (80000, 100000, 120000), 'max_iter': 1000, 'feature_maps': [38, 19, 10, 5, 3, 1], 'min_dim': 300, 'steps': [8, 16, 32, 64, 100, 300], 'min_sizes': [30, 60, 111, 162, 213, 264], 'max_sizes': [60, 111, 162, 213, 264, 315], 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]], 'variance': [0.1, 0.2], 'clip': True, 'name': 'VOC',}

最后一步,把coco_labels.txt放在ssd.pytorch-master/data/coco/目录下,也可以通过修改coco.py文件中的COCO_ROOT = osp.join(HOME, 'data/coco/')来指定存放路径。

(4)修改源码

①修改ssd.py文件中SSD类的__init__函数和forward函数,修改后的结果如下。

if phase == 'test':self.softmax = nn.Softmax(dim=-1) self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)修改为:if phase == 'test':self.softmax = nn.Softmax()self.detect = Detect()if self.phase == "test":output = self.detect( loc.view(loc.size(0), -1, 4), # loc preds self.softmax(conf.view(conf.size(0), -1, self.num_classes)), # conf preds self.priors.type(type(x.data)) # default boxes )修改为:if self.phase == "test":output = self.detect.apply(21, 0, 200, 0.01, 0.45, loc.view(loc.size(0), -1, 4), # loc preds self.softmax(conf.view(-1,21)), # conf predsself.priors.type(type(x.data)) # default boxes)

②修改train.py中187至189行代码,原因是.data[0]写法适用于低版本Pytorch,否则会出现IndexError:invalid index of a 0-dim tensor...错误,修改后的结果如下。

loc_loss += loss_l.item()conf_loss += loss_c.item()if iteration % 10 == 0: print('timer: %.4f sec.' % (t1 - t0))print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.item()), end=' ')

③交换layers/modules/multibox_loss.py中97行和98代码位置,否则会出现IndexError: The shape of the mask [14, 8732] at index 0does...错误,修改后的结果如下。

loss_c = loss_c.view(num, -1)loss_c[pos] = 0 # filter out pos boxes for now

④根据自己的需要对train.py中预训练模型、batch_size、学习率、模型名字和模型保存的次数等参数进行修改。建议学习率修改为1e-4(原因是原版使用1e-3可能会出现loss为nan情况),第一次调试时可以修改为每迭代100次保存,方便调试。

# 加载模型初始参数parser = argparse.ArgumentParser( description='Single Shot MultiBox Detector Training With Pytorch')train_set = parser.add_mutually_exclusive_group()# 默认加载VOC数据集parser.add_argument('--dataset', default='VOC', choices=['VOC', 'COCO'], type=str, help='VOC or COCO')# 设置VOC数据集根路径parser.add_argument('--dataset_root', default=VOC_ROOT, help='Dataset root directory path')# 设置预训练模型vgg16_reducedfc.pthparser.add_argument('--basenet', default='vgg16_reducedfc.pth', help='Pretrained base model')# 设置批大小,根据自己显卡能力设置,默认为32,此处我改为16parser.add_argument('--batch_size', default=16, type=int, help='Batch size for training')# 是否恢复中断的训练,默认不恢复parser.add_argument('--resume', default=None, type=str, help='Checkpoint state_dict file to resume training from')# 恢复训练iter数,默认从第0次迭代开始parser.add_argument('--start_iter', default=0, type=int, help='Resume training at this iter')# 数据加载线程数,根据自己CPU个数设置,默认为4parser.add_argument('--num_workers', default=4, type=int, help='Number of workers used in dataloading')# 是否使用CUDA加速训练,默认开启,如果没有GPU,可改成False直接用CPU训练parser.add_argument('--cuda', default=True, type=str2bool, help='Use CUDA to train model')# 学习率,默认0.001parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')# 最佳动量值,默认0.9(动量是梯度下降法中一种常用的加速技术,用于加速梯度下降,减少收敛耗时)parser.add_argument('--momentum', default=0.9, type=float, help='Momentum value for optim')# 权重衰减,即正则化项前面的系数,用于防止过拟合;SGD,即mini-batch梯度下降parser.add_argument('--weight_decay', default=1e-4, type=float, help='Weight decay for SGD')# gamma更新,默认值0.1parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD')# 使用visdom将训练过程loss图像可视化parser.add_argument('--visdom', default=False, type=str2bool, help='Use visdom for loss visualization')# 权重保存位置,默认存在weights/下parser.add_argument('--save_folder', default='weights/', help='Directory for saving checkpoint models')args = parser.parse_args()if iteration != 0 and iteration % 100 == 0:print('Saving state, iter:', iteration)torch.save(ssd_net.state_dict(), 'weights/ssd300_VOC_' + repr(iteration) + '.pth')

⑤因为pytorch1.9以上版本在这份源代码中并不适用,一旦运行cuda方面会报错如下:

RuntimeError: Expected a ‘cuda‘ device type for generator but found ‘cpu‘

参考github上的解决方法,有两种方法可成功运行:

第一种是重装pytorch1.8版本,就可以正常运行,但我觉得太麻烦了

第二种是修改源码:

在位于 anaconda 或任何地方的文件“site-packages/torch/utils/data/sampler.py”中。

[修改第 116 行]:generator = torch.Generator()改成generator = torch.Generator(device='cuda')[修改第 126 行]:yield from torch.randperm(n, generator=generator).tolist()改成yield from torch.randperm(n, generator=generator, device='cuda').tolist()

在train.py文件中,data.DataLoader处进行添加generator

data_loader = data.DataLoader(dataset, args.batch_size, num_workers=args.num_workers, shuffle=True, collate_fn=detection_collate, pin_memory=True, generator=torch.Generator(device='cuda'))

(5)运行train.py,如下图

参考资料:

SSD训练自己的数据集(pytorch版)_Kellenn的博客-CSDN博客_ssd训练自己的数据集pytorch

【目标检测实战】Pytorch—SSD模型训练(VOC数据集) - 知乎 (zhihu.com)

2.1SSD算法理论_哔哩哔哩_bilibili

本文链接地址:https://www.jiuchutong.com/zhishi/298420.html 转载请保留说明!

上一篇:超参数调优框架optuna(可配合pytorch)(超参数设置)

下一篇:点云数据的语义分割算法综述总结大全(传统方法+基于深度学习的方法)(什么叫点云数据)

  • iphone13照片排序设置(iphone相簿照片排序)

    iphone13照片排序设置(iphone相簿照片排序)

  • oppo一键锁屏怎么设置(oppo一键锁屏怎么添加到桌面)

    oppo一键锁屏怎么设置(oppo一键锁屏怎么添加到桌面)

  • 网易云怎么下载mp3格式的音乐(网易云怎么下载mv)

    网易云怎么下载mp3格式的音乐(网易云怎么下载mv)

  • cpu是什么处理器(苹果7cpu是什么处理器)

    cpu是什么处理器(苹果7cpu是什么处理器)

  • 抖音帮别人推速对方知道吗(帮别人dou速推会被别人知道吗)

    抖音帮别人推速对方知道吗(帮别人dou速推会被别人知道吗)

  • mac多选的快捷键(mac多选的快捷键是什么)

    mac多选的快捷键(mac多选的快捷键是什么)

  • 手机相册图片怎么剪切不要的部分(手机相册图片怎么拼图)

    手机相册图片怎么剪切不要的部分(手机相册图片怎么拼图)

  • 苹果6可以用5g网络吗(iphone 6可以用5g吗)

    苹果6可以用5g网络吗(iphone 6可以用5g吗)

  • 红米note2为什么那么耗电(红米note2为什么那么便宜)

    红米note2为什么那么耗电(红米note2为什么那么便宜)

  • 华为屏幕暗到看不见(华为屏幕暗到看不清东西)

    华为屏幕暗到看不见(华为屏幕暗到看不清东西)

  • dp75sdi是第几代(亚马逊dp75sdi是第几代)

    dp75sdi是第几代(亚马逊dp75sdi是第几代)

  • 华为mate30需要钢化膜吗(mate30带钢化膜吗)

    华为mate30需要钢化膜吗(mate30带钢化膜吗)

  • qq相册怎么上传照片(qq相册怎么上传大视频文件)

    qq相册怎么上传照片(qq相册怎么上传大视频文件)

  • 锐龙r73750h相当于i几(锐龙r7 3750h性能)

    锐龙r73750h相当于i几(锐龙r7 3750h性能)

  • 新装路由器已连接不可上网怎么回事(新装路由器已连接不上网)

    新装路由器已连接不可上网怎么回事(新装路由器已连接不上网)

  • 微信什么时候显示对方正在输入(微信什么时候显示忙线未接听)

    微信什么时候显示对方正在输入(微信什么时候显示忙线未接听)

  • vivo浏览器个人中心在哪(vivo浏览器个人中心在哪里找)

    vivo浏览器个人中心在哪(vivo浏览器个人中心在哪里找)

  • switch屏幕多大(switch的屏幕尺寸)

    switch屏幕多大(switch的屏幕尺寸)

  • ios12怎么手动清理缓存(ios12怎么清理缓存)

    ios12怎么手动清理缓存(ios12怎么清理缓存)

  • wps文档怎么查找文字(wps文档怎么查找重复项)

    wps文档怎么查找文字(wps文档怎么查找重复项)

  • 安居客发布的信息怎么删除(安居客发布的信息怎么修改)

    安居客发布的信息怎么删除(安居客发布的信息怎么修改)

  • 数字信号是一种数字式的什么信号(数字信号是一种什么脉冲序列)

    数字信号是一种数字式的什么信号(数字信号是一种什么脉冲序列)

  • 抖音长视频怎么看(抖音长视频怎么保存观看进度)

    抖音长视频怎么看(抖音长视频怎么保存观看进度)

  • qq影音如何放大视频(qq影音怎么调整视频比例)

    qq影音如何放大视频(qq影音怎么调整视频比例)

  • 链接怎么群发(链接怎么群发微信好友100个)

    链接怎么群发(链接怎么群发微信好友100个)

  • oppoa5指纹识别在哪(oppoa535g指纹)

    oppoa5指纹识别在哪(oppoa535g指纹)

  • 腾讯qq宠物停运补偿在哪领(qq宠物停服前三小时,最后的回忆!)

    腾讯qq宠物停运补偿在哪领(qq宠物停服前三小时,最后的回忆!)

  • 动态壁纸怎么设置(动态壁纸怎么设置锁屏)

    动态壁纸怎么设置(动态壁纸怎么设置锁屏)

  • 如何使用(扫描)二维码进行登录(如何使用扫描王)

    如何使用(扫描)二维码进行登录(如何使用扫描王)

  • 丢失增值税专用发票最新规定
  • 应交消费税的税目
  • 利润10万企业所得税多少
  • 企业需要政府哪方面政策支持
  • 车保险备注栏车船税如何记账
  • 现金流量表抵消分录
  • 其他应收款有哪些情况
  • 减免税控盘增值税纳税申报
  • 工业企业库存商品的初始入账成本
  • 已提足折旧的固定资产残值怎么处理
  • 企业所得税的应纳税所得额的扣除项目有哪些
  • 支付保险费发票怎么入账
  • 上市公司个税手续流程
  • 年收入超过12万什么时候申报
  • 企业房产税如何计算缴纳
  • 关于个体工商户的法律规定及司法解释
  • 建材公司将自产产品卖出
  • 接手新公司涉税问题分析
  • 总资产周转率ttm
  • 工资计税基数
  • 应收账款未计提坏账,但是确实收不回来
  • 一般纳税人购车可以抵扣多少税
  • 企业的其他业务是什么
  • 企业在外地的房产怎么办
  • 个税申报时个人怎么填
  • 固定资产更改折旧年限怎么账务处理
  • 收到退还的工会经费进什么科目
  • 电脑开始菜单在右边怎么调回来
  • vnisedit 打包
  • 笔记本电脑系统更新好不好
  • 企业所得税汇算清缴表
  • 收到厂家返利怎么做账务处理
  • 发票加盖发票章可以吗
  • 新事业单位会计准则
  • 代销商品受托方怎么做账
  • Uniapp使用$base方法
  • 育空河24102
  • php封装数据库连接
  • 动态设置窗体记录源属性
  • js实现拖拽选区的功能
  • metric命令
  • 个人所得税银行卡未实名认证是什么意思
  • 个人所得税汇算清缴时间
  • 非营利组织会计就是用于确认、计量
  • php访问mysql数据库函数
  • 返利销售的增值税怎么算
  • 如果没有抄税就申报了
  • mysql两张表差异数据
  • 对公账户有法律效力吗
  • 私车公用产生的费用算不算在公务用车运行维护费中
  • 建筑企业的安全技术措施
  • 出口退税的会计处理
  • 事业单位项目结算审计报告
  • 折扣怎么写会计分录
  • 长期股权投资权益法初始成本的确定
  • mysql8.0环境配置
  • 建立索引mysql
  • sqlserver 17051解决方案
  • win7电脑初始化
  • win8系统升级win8.1
  • xp远程连接win7
  • Qoeloader.exe - Qoeloader是什么进程 有什么用
  • win8打开ie
  • opengl mesa
  • nodejs实战教程
  • angular做app
  • django发送请求
  • chrome excel
  • jquery弹窗弹出一个页面
  • 安卓手机管家推荐
  • unity最新教程
  • 四川税务局网上办税
  • 河东区地税局上班时间
  • 如何做好税收工作推动税收事业创新发展
  • 国家税务总局服务平台
  • 陕西省地方税务局公告2016年第1号
  • 降低税率的坏处
  • 国外寄回来的奶粉被海关查到剪开,快递公司怎么处理
  • 铜陵职业技术学院专业
  • 党建税收宣传
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设