位置: IT常识 - 正文

Pytorch中的grid_sample算子功能解析(pytorch中的数据类型)

编辑:rootadmin
Pytorch中的grid_sample算子功能解析

推荐整理分享Pytorch中的grid_sample算子功能解析(pytorch中的数据类型),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch中的loss函数,pytorch中的forward函数,pytorch中的数据类型,pytorch中的tensor,pytorch中的view函数,pytorch中的张量,pytorch中的loss函数,pytorch中的tensor,内容如对您有帮助,希望把文章链接给更多的朋友!

         pytorch中的grid_sample是一种特殊的采样算法。

调用接口为:

torch.nn.functional.grid_sample(input,grid,mode='bilinear',padding_mode='zeros',align_corners=None)。

         input参数是输入特征图tensor,也就是特征图,可以是四维或者五维张量,以四维形式为例(N,C,Hin,Win),N可以理解为Batch_size,C可以理解为通道数,Hin和Win也就是特征图高和宽。

         grid包含输出特征图特征图的格网大小以及每个格网对应到输入特征图的采样点位,对应四维input,其张量形式为(N,Hout,Wout,2),其中最后一维大小必须为2,如果输入为五维张量,那么最后一维大小必须为3。为什么最后一维必须为2或者3?因为grid的最后一个维度实际上代表一个坐标(x,y)或者(xy,z),对应到输入特征图的二维或三维特征图的坐标维度,xy取值范围一般为[-1,1],该范围映射到输入特征图的全图。

         mode为选择采样方法,有三种内插算法可选,分别是'bilinear'双线性差值、'nearest'最邻近插值、'bicubic' 双三次插值。

Pytorch中的grid_sample算子功能解析(pytorch中的数据类型)

         padding_mode为填充模式,即当(x,y)取值超过输入特征图采样范围,返回一个特定值,有'zeros' 、 'border' 、 'reflection'三种可选,一般用zero。

         align_corners为bool类型,指设定特征图坐标与特征值对应方式,设定为TRUE时,特征值位于像素中心。

         要理解grid_sample是如何工作的,最好就是进行简单的复现。假设输入shape为(N,C,H,W),grid的shape设定为(N,H,W,2),以双线性差值为例进行处理。首先根据input和grid设定,输出特征图tensor的shape为(N,C,H,W),输出特征图上每一个cell上的值由grid最后一维(x,y)确定。那么如何计算输出tensor上每一个点的值?首先,通过(x,y)找到输入特征图上的采样位置,由于xy取值范围为[-1,1],为了便于计算,先将xy取值范围调整为[0,1]。通过(w-1)*(x+1)/2、(wh-1)*(y+1)/2将xy映射为输入特征图的具体坐标位置。将xy映射到特征图实际坐标后,取该坐标附近四个角点特征值,通过四个特征值坐标与采样点坐标相对关系进行双线性插值,得到采样点的值。

注意:xy映射后的坐标可能是输入特征图上任意位置。假设输出特征图上(2,2)坐标位置上的值采样位置可能为输入特征图上(3,4)位置,xy越小越靠近输入特征图左上角,越大则越靠近右下角。

         基于上面的思路,可以进行一个简单的自定义实现。根据指定shape生成input和grid,使用pytorch中的grid_sample算子生成output。之后取grid中的第一个位置中的xy,根据xy从input中通过双线性插值计算出output第一个位置的值。

import torchimport numpy as npdef grid_sample(input, grid): N, C, H_in, W_in = input.shape N, H_out, W_out, _ = grid.shape output = np.random.random((N,C,H,W)) for i in range(N): for j in range(C): for k in range(H_out): for l in range(W_out): param = [0.0, 0.0] param[0] = (W_in - 1) * (grid[i][k][l][0] + 1) / 2 param[1] = (H_in - 1) * (grid[i][k][l][1] + 1) / 2 x0 = int(param[0]) x1 = x0 + 1 y0 = int(param[1]) y1 = y0 + 1 param[0] -= x0 param[1] -= y0 left_top = input[i][j][y0][x0] * (1 - param[0]) * (1 - param[1]) left_bottom = input[i][j][y1][x0] * (1 - param[0]) * param[1] right_top = input[i][j][y0][x1] * param[0] * (1 - param[1]) right_bottom = input[i][j][y1][x1] * param[0] * param[1] result = left_bottom + left_top + right_bottom + right_top output[i][j][k][l] = result return outputN, C, H, W = 1, 1, 4, 4input = np.random.random((N,C,H,W))grid = np.random.random((N,H,W,2))out = grid_sample(input, grid)print(f'自定义实现输出结果:\n{out}')input = torch.from_numpy(input)grid = torch.from_numpy(grid)output = torch.nn.functional.grid_sample(input,grid,mode='bilinear', padding_mode='zeros',align_corners=True)print(f'grid_sample输出结果:\n{output}')

运行结果:

         从输出结果上看,与pytorch基本一致,由于仅仅做简单验证,这里没有对超出[-1,1]范围的xy值做处理,只能处理四维input,五维input的实现思路与这里基本一致。

        考虑到(x,y)取值范围可能越界,pytorch中的padding_mode设置就是对(x,y)落在输入特征图外边缘情况进行处理,一般设置'zero',也就是对靠近输入特征图范围以外的采样点进行0填充,如果不进行处理显然会造成索引越界。要解决(x,y)越界问题,可以进行如下修改:

import torchimport numpy as npdef grid_sample(input, grid): N, C, H_in, W_in = input.shape N, H_out, W_out, _ = grid.shape output = np.random.random((N, C, H_out, W_out)) for i in range(N): for j in range(C): for k in range(H_out): for l in range(W_out): x, y = grid[i][k][l][0], grid[i][k][l][1] param = [0.0, 0.0] param[0] = (W_in - 1) * (x + 1) / 2 param[1] = (H_in - 1) * (y + 1) / 2 x1 = int(param[0] + 1) x0 = x1 - 1 y1 = int(param[1] + 1) y0 = y1 - 1 param[0] = abs(param[0] - x0) param[1] = abs(param[1] - y0) left_top_value, left_bottom_value, right_top_value, right_bottom_value = 0, 0, 0, 0 if 0 <= x0 < W_in and 0 <= y0 < H_in: left_top_value = input[i][j][y0][x0] if 0 <= x1 < W_in and 0 <= y0 < H_in: right_top_value = input[i][j][y0][x1] if 0 <= x0 < W_in and 0 <= y1 < H_in: left_bottom_value = input[i][j][y1][x0] if 0 <= x1 < W_in and 0 <= y1 < H_in: right_bottom_value = input[i][j][y1][x1] left_top = left_top_value * (1 - param[0]) * (1 - param[1]) left_bottom = left_bottom_value * (1 - param[0]) * param[1] right_top = right_top_value * param[0] * (1 - param[1]) right_bottom = right_bottom_value * param[0] * param[1] result = left_bottom + left_top + right_bottom + right_top output[i][j][k][l] = result return outputN, C, H_in, W_in = 1, 1, 4, 4H_out, W_out = 4, 4input = np.random.random((N, C, H_in, W_in))grid = np.random.random((N, H_out, W_out, 2))grid[0][0][0] = [-1.2, 1.3]out = grid_sample(input, grid)print(f'自定义实现输出结果:\n{out}')input = torch.from_numpy(input)grid = torch.from_numpy(grid)output = torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=True)print(f'grid_sample输出结果:\n{output}')

     测试结果:

   

本文链接地址:https://www.jiuchutong.com/zhishi/295209.html 转载请保留说明!

上一篇:成功解决:npm 版本不支持node.js。【 npm v9.1.2 does not support Node.js v16.6.0.】(成功解决冲突的能力英语)

下一篇:Vue 中 forEach() 的使用(vue foreach is not a function)

  • 关于教师节的感恩的话(关于教师节的感谢语)(关于教师节的感恩词)

    关于教师节的感恩的话(关于教师节的感谢语)(关于教师节的感恩词)

  • 抖音秒关注秒取消会知道么(抖音秒关注秒取消对方给我打招呼)

    抖音秒关注秒取消会知道么(抖音秒关注秒取消对方给我打招呼)

  • 文件被覆盖了还能恢复吗(文件被覆盖了还能找到吗?)

    文件被覆盖了还能恢复吗(文件被覆盖了还能找到吗?)

  • 猫是不是路由器(哪个是猫哪个是路由器)

    猫是不是路由器(哪个是猫哪个是路由器)

  • 苹果截图编辑怎么关闭(苹果截图编辑怎么打马赛克)

    苹果截图编辑怎么关闭(苹果截图编辑怎么打马赛克)

  • 腾讯会议锁屏会退出吗(腾讯会议锁屏会记录时长吗)

    腾讯会议锁屏会退出吗(腾讯会议锁屏会记录时长吗)

  • 弹幕突然没有了(这两天弹幕没了)

    弹幕突然没有了(这两天弹幕没了)

  • 华为mate20pro拍照教程是什么(华为mate20pro拍照得分)

    华为mate20pro拍照教程是什么(华为mate20pro拍照得分)

  • 微信上不了怎么弄回来(微信上不了怎么注销)

    微信上不了怎么弄回来(微信上不了怎么注销)

  • 快手发的视频怎么删除(快手发的视频怎么删除掉)

    快手发的视频怎么删除(快手发的视频怎么删除掉)

  • 华为手环一定要华为手机吗(华为手环一定要华为账号吗)

    华为手环一定要华为手机吗(华为手环一定要华为账号吗)

  • switch如何区分续航版(switch怎么区分续航)

    switch如何区分续航版(switch怎么区分续航)

  • 微信无缘无故被永久封号怎么办(微信无缘无故被投诉了、是怎么回事)

    微信无缘无故被永久封号怎么办(微信无缘无故被投诉了、是怎么回事)

  • 小米抖音看完整版怎么看(小米手机抖音为什么看不到完整版)

    小米抖音看完整版怎么看(小米手机抖音为什么看不到完整版)

  • 闲鱼我超赞的在哪里找(我的闲鱼超赞不见了)

    闲鱼我超赞的在哪里找(我的闲鱼超赞不见了)

  • 魅族16th有快充吗(魅族16th快充多少w)

    魅族16th有快充吗(魅族16th快充多少w)

  • 主机是指什么(手机主机是指什么)

    主机是指什么(手机主机是指什么)

  • 华为p30可以登录几个微信(华为p30登录qq怎么显示)

    华为p30可以登录几个微信(华为p30登录qq怎么显示)

  • 华为mate30自带贴膜吗(华为mate30自带贴膜涩)

    华为mate30自带贴膜吗(华为mate30自带贴膜涩)

  • 淘宝详情页模板怎么做(淘宝详情页模板在哪里)

    淘宝详情页模板怎么做(淘宝详情页模板在哪里)

  • 电话免打扰在哪里设置(电话免打扰在哪里关闭)

    电话免打扰在哪里设置(电话免打扰在哪里关闭)

  • 小米5x遥控器在哪里(小米a55遥控器)

    小米5x遥控器在哪里(小米a55遥控器)

  • 华为手机怎么才算激活(华为手机怎么才能安装未知应用)

    华为手机怎么才算激活(华为手机怎么才能安装未知应用)

  • vivo怎么关闭sos紧急呼叫(vivo手机怎么关闭全局搜索)

    vivo怎么关闭sos紧急呼叫(vivo手机怎么关闭全局搜索)

  • 阿里卖家故意不发货(阿里卖家不发货怎么办)

    阿里卖家故意不发货(阿里卖家不发货怎么办)

  • vue2 vue-router 不显示页面问题

    vue2 vue-router 不显示页面问题

  • 总公司是小规模分公司是一般纳税人
  • 车险发票不含车船税怎么记账
  • 防伪税控风险纳税人财务负责人和法人同一人
  • 玉米大量收购
  • 没有抵扣的进项发票,开错了对方没有作废
  • 增值税专用发票税额怎么抵扣
  • 车船使用税应该交哪里的税
  • 拿到一个材料如何加工
  • 一般纳税人企业所得税政策最新2023税率
  • 审计费用收取标准的2020
  • 公司注销欠法人款怎么帐务处理?
  • 未开票收入可以填写负数吗
  • 酒店小规模纳税人税率
  • 营业外收入征企业所得税吗
  • 证券投资基金管理人的职权
  • 福利企业的增值税是多少
  • 金银首饰的消费税税务处理
  • 包工包料的工程怎么做账
  • 手动设定ip地址后连不上网
  • mac safari使用技巧
  • 电脑很空但是占用率90
  • mac 如何u盘启动
  • 只有高新技术企业能享受研发加计扣除吗
  • 金融资产包括哪三大类及会计科目
  • 坏账准备的方法
  • 查补以前年度收入
  • 公司帮员工买社保能扣税吗
  • 实际收到的货款怎么做账
  • 在建工程计提减值准备计入什么科目
  • php模板引擎原理
  • php array操作
  • thinkphp教程
  • php源码 数据库
  • 个体工商户个税优惠政策2023
  • 回扣,折扣和佣金都具有违法性对吗
  • 织梦如何添加浮动广告
  • sql2005安装不上
  • 数据库管理中负责数据模式定义的数据库语言是
  • 什么是全面一次性奖金
  • 开发票该怎么操作?
  • 以前年度损益影响当期损益吗
  • 无形资产的会计准则的相关规定
  • 商场联营方案
  • 公关费用计入什么科目比较好
  • 融资租赁承租方怎么做账
  • 转出未交增值税借方余额怎么处理
  • 为什么零售业只进不出呢
  • 股东变更需要哪些资料和手续
  • win7系统如何提升性能
  • win8系统如何安装软件
  • win8系统升级
  • win7系统如何更改默认浏览器
  • win8右下角
  • suse linux教程
  • win8垃圾清理
  • 建立一个新用户并把它加入wheel组,设置用户密码为123
  • 勒索病毒一般勒索多少钱
  • 用360可以装win7系统吗
  • minidump文件怎么打开
  • javascript怎么学
  • eclipse从本地导入项目
  • python算法具有哪五个性质
  • python中random模块用法
  • python socket编程步骤
  • javascript详细介绍
  • hbase获取所有表
  • jquery教程chm
  • javascriptz
  • 安卓通知栏管理工具
  • 房屋附属设备和配套设施计征房产税
  • 高新区税务局发工资时间
  • 税收分类分级管理后如何开展风险管理
  • 上海交电费户号8位数
  • 甘肃省契税征收标准
  • 陕西省税务电话是多少
  • 个人所得税完税证明在哪里查询打印
  • 区里的地税局局长是谁
  • 社保已生成单据如何作废上海
  • 西安税务局服务电话
  • 优化营商环境关于人才工作
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设