位置: IT常识 - 正文

Python深度学习实战:人脸关键点(15点)检测pytorch实现

编辑:rootadmin
Python深度学习实战:人脸关键点(15点)检测pytorch实现 引言

推荐整理分享Python深度学习实战:人脸关键点(15点)检测pytorch实现,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

Python深度学习实战:人脸关键点(15点)检测pytorch实现

人脸关键点检测即对人类面部若干个点位置进行检测,可以通过这些点的变化来实现许多功能,该技术可以应用到很多领域,例如捕捉人脸的关键点,然后驱动动画人物做相同的面部表情;识别人脸的面部表情,让机器能够察言观色等等。

如何检测人脸关键点

本文是实现15点的检测,至于N点的原理都是一样的,使用的算法模型是深度神经网络,使用CV也是可以的。

如何检测

这个问题抽象出来,就是一个使用神经网络来进行预测的功能,只不过输出是15个点的坐标,训练数据包含15个面部的特征点和面部的图像(大小为96x96),15个特征点分别是:left_eye_center, right_eye_center, left_eye_inner_corner, left_eye_outer_corner, right_eye_inner_corner, right_eye_outer_corner, left_eyebrow_inner_end, left_eyebrow_outer_end, right_eyebrow_inner_end, right_eyebrow_outer_end, nose_tip, mouth_left_corner, mouth_right_corner, mouth_center_top_lip, mouth_center_bottom_lip 因此神经网络需要学习一个从人脸图像到15个关键点坐标间的映射。

使用的网络结构

在本文中,我们使用深度神经网络来实现该功能,基本卷积块使用Google的Inception网络,也就是使用GoogLeNet网络,该结构的网络是基于卷积神经网络来改进的,是一个含有并行连接的网络。 众所周知,卷积有滤波、提取特征的作用,但到底采用多大的卷积来提取特征是最好的呢?这个问题没有确切的答案,那就集百家之长:使用多个形状不一的卷积来提取特征并进行拼接,从而学习到更为丰富的特征;特别是里面加上了1x1的卷积结构,能够实现跨通道的信息交互和整合(其本质就是在多个channel上的线性求和),同时能在feature map通道数上的降维(读者可以验证计算一下,能够极大减少卷积核的参数),也能够增加非线性映射次数使得网络能够更深。 下面是Inception块的示意图: 整个GoogLeNet的结构如下所示: 接下来是代码实现部分,后续作者会补充神经网络的相关原理知识,若对此感兴趣的读者也可继续关注支持~

代码实现import torch as tcfrom torch import nnfrom torch.nn import functional as Ffrom torch.utils.data import DataLoaderfrom torch.utils.data import TensorDatasetimport numpy as npimport matplotlib.pyplot as pltimport pandas as pdfrom sklearn.utils import shuffle# 对图片像素的处理def proFunc1(data,testFlag:bool=False) -> tuple: data['Image'] = data['Image'].apply(lambda im: np.fromstring(im, sep=' ')) # 处理na data = data.dropna() # 神经网络对数据范围较为敏感 /255 将所有像素都弄到[0,1]之间 X = np.vstack(data['Image'].values) / 255 X = X.astype(np.float32) # 特别注意 这里要变成 n channle w h 要跟卷积第一层相匹配 X = X.reshape(-1, 1,96, 96) # 等会神经网络的输入层就是 96 96 黑白图片 通道只有一个 # 只有训练集才有y 测试集返回一个None出去 if not testFlag: y = data[data.columns[:-1]].values # 规范化 y = (y - 48) / 48 X, y = shuffle(X, y, random_state=42) y = y.astype(np.float32) else: y = None return X,y# 工具类class UtilClass: def __init__(self,model,procFun,trainFile:str='data/training.csv',testFile:str='data/test.csv') -> None: self.trainFile = trainFile self.testFile = testFile self.trainData = None self.testData = None self.trainTarget = None self.model = model self.procFun = procFun @staticmethod def procData(data, procFunc ,testFlag:bool=False) -> tuple: return procFunc(data,testFlag) def loadResource(self): rawTrain = pd.read_csv(self.trainFile) rawTest = pd.read_csv(self.testFile) self.trainData , self.trainTarget = self.procData(rawTrain,self.procFun) self.testData , _ = self.procData(rawTest,self.procFun,testFlag=True) def getTrain(self): return tc.from_numpy(self.trainData), tc.from_numpy(self.trainTarget) def getTest(self): return tc.from_numpy(self.testData) @staticmethod def plotData(img, keyPoints, axis): axis.imshow(np.squeeze(img), cmap='gray') # 恢复到原始像素数据 keyPoints = keyPoints * 48 + 48 # 把keypoint弄到图上面 axis.scatter(keyPoints[0::2], keyPoints[1::2], marker='o', c='c', s=40)# 自定义的卷积神经网络class MyCNN(tc.nn.Module): def __init__(self,imgShape = (96,96,1),keyPoint:int = 15): super(MyCNN, self).__init__() self.conv1 = tc.nn.Conv2d(in_channels=1, out_channels =10, kernel_size=3) self.pooling = tc.nn.MaxPool2d(kernel_size=2) self.conv2 = tc.nn.Conv2d(10, 5, kernel_size=3) # 这里的2420是通过下面的计算得出的 如果改变神经网络结构了 # 需要计算最后的Liner的in_feature数量 输出是固定的keyPoint*2 self.fc = tc.nn.Linear(2420, keyPoint*2) def forward(self, x): # print("start----------------------") batch_size = x.size(0) # x = x.view((-1,1,96,96)) # print('after view shape:',x.shape) x = F.relu(self.pooling(self.conv1(x))) # print('conv1 size',x.shape) x = F.relu(self.pooling(self.conv2(x))) # print('conv2 size',x.shape) # print('end--------------------------') # 改形状 x = x.view(batch_size, -1) # print(x.shape) x = self.fc(x) # print(x.shape) return x# GoogleNet基本的卷积块class MyInception(nn.Module): def __init__(self,in_channels, c1, c2, c3, c4,) -> None: super().__init__() self.p1_1 = nn.Conv2d(in_channels, c1, kernel_size=1) self.p2_1 = nn.Conv2d(in_channels, c2[0], kernel_size=1) self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1) self.p3_1 = nn.Conv2d(in_channels, c3[0], kernel_size=1) self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2) self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) self.p4_2 = nn.Conv2d(in_channels, c4, kernel_size=1) def forward(self, x): p1 = F.relu(self.p1_1(x)) p2 = F.relu(self.p2_2(F.relu(self.p2_1(x)))) p3 = F.relu(self.p3_2(F.relu(self.p3_1(x)))) p4 = F.relu(self.p4_2(self.p4_1(x))) # 在通道维度上连结输出 return tc.cat((p1, p2, p3, p4), dim=1)# GoogLeNet的设计 此处参数结果google大量实验得出b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3), nn.ReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1))b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1), nn.ReLU(), nn.Conv2d(64, 192, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1))b3 = nn.Sequential(MyInception(192, 64, (96, 128), (16, 32), 32), MyInception(256, 128, (128, 192), (32, 96), 64), nn.MaxPool2d(kernel_size=3, stride=2, padding=1))b4 = nn.Sequential(MyInception(480, 192, (96, 208), (16, 48), 64), MyInception(512, 160, (112, 224), (24, 64), 64), MyInception(512, 128, (128, 256), (24, 64), 64), MyInception(512, 112, (144, 288), (32, 64), 64), MyInception(528, 256, (160, 320), (32, 128), 128), nn.MaxPool2d(kernel_size=3, stride=2, padding=1))b5 = nn.Sequential(MyInception(832, 256, (160, 320), (32, 128), 128), MyInception(832, 384, (192, 384), (48, 128), 128), nn.AdaptiveAvgPool2d((1,1)), nn.Flatten())uClass = UtilClass(model=None,procFun=proFunc1)uClass.loadResource()xTrain ,yTrain = uClass.getTrain()xTest = uClass.getTest()dataset = TensorDataset(xTrain, yTrain)trainLoader = DataLoader(dataset, 64, shuffle=True, num_workers=4)# 训练net并进行测试 由于显示篇幅问题 只能打印出极为有限的若干测试图片效果def testCode(net): optimizer = tc.optim.Adam(params=net.parameters()) criterion = tc.nn.MSELoss() for epoch in range(30): trainLoss = 0.0 # 这里是用的是mini_batch 也就是说 每次只使用mini_batch个数据大小来计算 # 总共有total个 因此总共训练 total/mini_batch 次 # 由于不能每组数据只使用一次 所以在下面还要使用一个for循环来对整体训练多次 for batchIndex, data in enumerate(trainLoader, 0): input_, y = data yPred = net(input_) loss = criterion(yPred, y) optimizer.zero_grad() loss.backward() optimizer.step() trainLoss += loss.item() # 只在每5个epoch的最后一轮打印信息 if batchIndex % 30 ==29 and not epoch % 5 : print("[{},{}] loss:{}".format(epoch + 1, batchIndex + 1, trainLoss / 300)) trainLoss = 0.0 # 测试 print("-----------test begin-------------") # print(xTest.shape) yPost = net(xTest) # print(yPost.shape) import matplotlib.pyplot as plt %matplotlib inline fig = plt.figure(figsize=(20,20)) fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05) for i in range(9,18): ax = fig.add_subplot(3, 3, i - 9 + 1, xticks=[], yticks=[]) uClass.plotData(xTest[i], y[i], ax) print("-----------test end-------------")if __name__ == "__main__": # 训练MyCNN网络 并可视化在9个测试数据的效果图 myNet = MyCNN() testCode(myNet) inception = nn.Sequential(b1, b2, b3, b4, b5, nn.Linear(1024, 30))testCode(inception)

本文使用的数据可在此找到两个data文件,本文有你帮助的话,就给个点赞关注支持一下吧!

本文链接地址:https://www.jiuchutong.com/zhishi/300379.html 转载请保留说明!

上一篇:【IIS搭建网站】本地电脑做服务器搭建web站点并公网访问「内网穿透」(iis搭建网站教程win10)

下一篇:(二)元学习算法MAML简介及代码分析(二元运算例子)

  • 怎么定时发送微信(怎么定时发送微信消息OPPO)

    怎么定时发送微信(怎么定时发送微信消息OPPO)

  • 华为手环丢了如何找回(华为手环丢了如何解除配对)

    华为手环丢了如何找回(华为手环丢了如何解除配对)

  • 快手直播口令红包怎么输入答案(快手直播口令红包怎么找到)

    快手直播口令红包怎么输入答案(快手直播口令红包怎么找到)

  • 三星s10降噪孔捅一下会坏吗

    三星s10降噪孔捅一下会坏吗

  • 华为怎么下载谷歌商店(华为怎么下载谷歌地图)

    华为怎么下载谷歌商店(华为怎么下载谷歌地图)

  • 一体手机进水晾多久开机(一体手机进水晾干后充电慢)

    一体手机进水晾多久开机(一体手机进水晾干后充电慢)

  • 什么是外播已转接来电(外播已转播来电是什么意思)

    什么是外播已转接来电(外播已转播来电是什么意思)

  • iPhone11一晚上充电可以吗(iphone11一整晚充电)

    iPhone11一晚上充电可以吗(iphone11一整晚充电)

  • win7触摸板不能用(win7触摸屏无法触摸)

    win7触摸板不能用(win7触摸屏无法触摸)

  • 在word中按什么键与工具栏上的复制按钮功能相同(在word中按什么按钮可以改变字符底纹)

    在word中按什么键与工具栏上的复制按钮功能相同(在word中按什么按钮可以改变字符底纹)

  • 电脑装完系统后还是不能进入windows(电脑装完系统后重启就进不了系统)

    电脑装完系统后还是不能进入windows(电脑装完系统后重启就进不了系统)

  • 滴滴半天不派单是为什么(滴滴今天突然不派单了)

    滴滴半天不派单是为什么(滴滴今天突然不派单了)

  • 微信无法接受视频通话(微信无法接受视频片子)

    微信无法接受视频通话(微信无法接受视频片子)

  • psg1218是千兆路由器吗(psg1218 k2路由器是百兆还是千兆)

    psg1218是千兆路由器吗(psg1218 k2路由器是百兆还是千兆)

  • qqvip怎么升级到vip2(qqvlp怎么升级)

    qqvip怎么升级到vip2(qqvlp怎么升级)

  • p40是什么屏幕(p40是什么屏幕供应商)

    p40是什么屏幕(p40是什么屏幕供应商)

  • 网易云音乐古典专区在哪里(网易云音乐古典音乐数据)

    网易云音乐古典专区在哪里(网易云音乐古典音乐数据)

  • sata2和sata3外观区别(sata2和sata3接口区别大吗)

    sata2和sata3外观区别(sata2和sata3接口区别大吗)

  • qq特别关心有什么功能(qq特别关心有什么特殊的吗)

    qq特别关心有什么功能(qq特别关心有什么特殊的吗)

  • ipad air3什么时候出(ipad air3什么时候发布的)

    ipad air3什么时候出(ipad air3什么时候发布的)

  • iphone7掉水里能用吗(iphone7掉进水里)

    iphone7掉水里能用吗(iphone7掉进水里)

  • 夸克链信怎么注销(夸克链信注册)

    夸克链信怎么注销(夸克链信注册)

  • 手机信息显示怎样设置(手机信息显示怎么取消掉)

    手机信息显示怎样设置(手机信息显示怎么取消掉)

  • 荣耀v20关闭后台不显示(荣耀v20怎么关闭)

    荣耀v20关闭后台不显示(荣耀v20怎么关闭)

  • 华为nova3长多少厘米(华为nova3e手机长度是多少)

    华为nova3长多少厘米(华为nova3e手机长度是多少)

  • 手机互传在哪里(oppo手机互传在哪里)

    手机互传在哪里(oppo手机互传在哪里)

  • dvi接口有几种(dvi接口有几种类型)

    dvi接口有几种(dvi接口有几种类型)

  • 抖音对方把我拉黑了还能刷到我吗(抖音对方把我拉黑了我看他主页他访客有我记录吗)

    抖音对方把我拉黑了还能刷到我吗(抖音对方把我拉黑了我看他主页他访客有我记录吗)

  • 时间财富原名叫什么(时间财富app)

    时间财富原名叫什么(时间财富app)

  • 待抵扣进项税的账务处理
  • 纳税人识别号是什么哪里可以查到
  • 发票签字有什么用
  • 母子公司的关联交易怎么看
  • 个人劳务费的免税政策
  • 技术维护费计入哪里
  • 缴纳的专利年费能退吗
  • 旧设备换新设备文案
  • 债转股企业所得税资本公积
  • 未结转损益可以结账吗
  • 小规模代扣代缴个税会计分录
  • 公司贷款买车后影响公司收购吗
  • 为什么中国没有工业革命
  • 税控盘抵扣增值税怎么做账
  • 简易征收货物的运费
  • 小规模纳税人需要汇算清缴吗
  • 物业公司代收水费亏损谁承担
  • 维修属于劳务还是劳务
  • 外贸企业零退税怎么算
  • 预收房款属于什么科目
  • 外贸企业出口退税申报期限
  • 跨年会计分录错误
  • 破产清算应付账款
  • 如何做预估成本
  • 稿酬计入工资所得吗
  • 职工教育经费包括餐费吗
  • php获取访问者qq
  • 预计负债属于什么类
  • info.exe
  • 外籍人员个人所得税政策2023规定
  • 城市里创业
  • 大型绿萝的养殖方法
  • 午夜太阳的意思
  • ROS2+cartographer+激光雷达+IMU里程计数据融合(robot_locazation) 建图
  • java实现电子发票
  • phpcms怎么样
  • 开具劳务发票需要提供什么资料?
  • 帝国cms导航站模板
  • python一元二次方程求根
  • php实现站内消息推送
  • 以前年度损益调整会计分录
  • 一般纳税人零申报怎么报税步骤
  • 购买铝材会计分录
  • 固定资产的弃置费用
  • 工程类企业存货
  • 其他应付款二级明细科目有哪些
  • 餐饮业税务申报
  • 资产负债表中没有专项储备怎么填写
  • sql中循环语句怎么写
  • 税控盘有什么作用
  • 电影院租金
  • 工程的挂靠取得收入怎么做账?
  • 生产成本人工费结转
  • 管理费用和销售费用属于什么科目
  • 经营性应付项目减少对经营活动现金
  • 微软系统无法开机怎么办
  • 怎么在bios中开启cs1
  • wbs是什么文件
  • 电脑百度搜索
  • 1sass.exe是什么程序
  • vmware15.5安装mac
  • win7打开文件夹都是独立的窗口
  • window10稳定版
  • css实战手册
  • unity粒子制作ui特效
  • python字典常用操作以及字典的嵌套
  • perl use of uninitialized
  • 批处理杀死进程
  • python的pip安装命令
  • python基础总结
  • 一道关于医用口罩的数学题初二
  • jquery 输入框输入完触发事件
  • javascript总结笔记
  • shell脚本判断命令是否执行成功
  • javascript怎么学
  • javascript playground
  • 关联公司销售
  • 00后先进人物事迹简介
  • 天猫主体变更是什么意思
  • 社保费是国税还是地税
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设