位置: IT常识 - 正文

Pytorch中的grid_sample算子功能解析(pytorch中的数据类型)

编辑:rootadmin
Pytorch中的grid_sample算子功能解析

推荐整理分享Pytorch中的grid_sample算子功能解析(pytorch中的数据类型),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch中的loss函数,pytorch中的forward函数,pytorch中的数据类型,pytorch中的tensor,pytorch中的view函数,pytorch中的张量,pytorch中的loss函数,pytorch中的tensor,内容如对您有帮助,希望把文章链接给更多的朋友!

         pytorch中的grid_sample是一种特殊的采样算法。

调用接口为:

torch.nn.functional.grid_sample(input,grid,mode='bilinear',padding_mode='zeros',align_corners=None)。

         input参数是输入特征图tensor,也就是特征图,可以是四维或者五维张量,以四维形式为例(N,C,Hin,Win),N可以理解为Batch_size,C可以理解为通道数,Hin和Win也就是特征图高和宽。

         grid包含输出特征图特征图的格网大小以及每个格网对应到输入特征图的采样点位,对应四维input,其张量形式为(N,Hout,Wout,2),其中最后一维大小必须为2,如果输入为五维张量,那么最后一维大小必须为3。为什么最后一维必须为2或者3?因为grid的最后一个维度实际上代表一个坐标(x,y)或者(xy,z),对应到输入特征图的二维或三维特征图的坐标维度,xy取值范围一般为[-1,1],该范围映射到输入特征图的全图。

         mode为选择采样方法,有三种内插算法可选,分别是'bilinear'双线性差值、'nearest'最邻近插值、'bicubic' 双三次插值。

Pytorch中的grid_sample算子功能解析(pytorch中的数据类型)

         padding_mode为填充模式,即当(x,y)取值超过输入特征图采样范围,返回一个特定值,有'zeros' 、 'border' 、 'reflection'三种可选,一般用zero。

         align_corners为bool类型,指设定特征图坐标与特征值对应方式,设定为TRUE时,特征值位于像素中心。

         要理解grid_sample是如何工作的,最好就是进行简单的复现。假设输入shape为(N,C,H,W),grid的shape设定为(N,H,W,2),以双线性差值为例进行处理。首先根据input和grid设定,输出特征图tensor的shape为(N,C,H,W),输出特征图上每一个cell上的值由grid最后一维(x,y)确定。那么如何计算输出tensor上每一个点的值?首先,通过(x,y)找到输入特征图上的采样位置,由于xy取值范围为[-1,1],为了便于计算,先将xy取值范围调整为[0,1]。通过(w-1)*(x+1)/2、(wh-1)*(y+1)/2将xy映射为输入特征图的具体坐标位置。将xy映射到特征图实际坐标后,取该坐标附近四个角点特征值,通过四个特征值坐标与采样点坐标相对关系进行双线性插值,得到采样点的值。

注意:xy映射后的坐标可能是输入特征图上任意位置。假设输出特征图上(2,2)坐标位置上的值采样位置可能为输入特征图上(3,4)位置,xy越小越靠近输入特征图左上角,越大则越靠近右下角。

         基于上面的思路,可以进行一个简单的自定义实现。根据指定shape生成input和grid,使用pytorch中的grid_sample算子生成output。之后取grid中的第一个位置中的xy,根据xy从input中通过双线性插值计算出output第一个位置的值。

import torchimport numpy as npdef grid_sample(input, grid): N, C, H_in, W_in = input.shape N, H_out, W_out, _ = grid.shape output = np.random.random((N,C,H,W)) for i in range(N): for j in range(C): for k in range(H_out): for l in range(W_out): param = [0.0, 0.0] param[0] = (W_in - 1) * (grid[i][k][l][0] + 1) / 2 param[1] = (H_in - 1) * (grid[i][k][l][1] + 1) / 2 x0 = int(param[0]) x1 = x0 + 1 y0 = int(param[1]) y1 = y0 + 1 param[0] -= x0 param[1] -= y0 left_top = input[i][j][y0][x0] * (1 - param[0]) * (1 - param[1]) left_bottom = input[i][j][y1][x0] * (1 - param[0]) * param[1] right_top = input[i][j][y0][x1] * param[0] * (1 - param[1]) right_bottom = input[i][j][y1][x1] * param[0] * param[1] result = left_bottom + left_top + right_bottom + right_top output[i][j][k][l] = result return outputN, C, H, W = 1, 1, 4, 4input = np.random.random((N,C,H,W))grid = np.random.random((N,H,W,2))out = grid_sample(input, grid)print(f'自定义实现输出结果:\n{out}')input = torch.from_numpy(input)grid = torch.from_numpy(grid)output = torch.nn.functional.grid_sample(input,grid,mode='bilinear', padding_mode='zeros',align_corners=True)print(f'grid_sample输出结果:\n{output}')

运行结果:

         从输出结果上看,与pytorch基本一致,由于仅仅做简单验证,这里没有对超出[-1,1]范围的xy值做处理,只能处理四维input,五维input的实现思路与这里基本一致。

        考虑到(x,y)取值范围可能越界,pytorch中的padding_mode设置就是对(x,y)落在输入特征图外边缘情况进行处理,一般设置'zero',也就是对靠近输入特征图范围以外的采样点进行0填充,如果不进行处理显然会造成索引越界。要解决(x,y)越界问题,可以进行如下修改:

import torchimport numpy as npdef grid_sample(input, grid): N, C, H_in, W_in = input.shape N, H_out, W_out, _ = grid.shape output = np.random.random((N, C, H_out, W_out)) for i in range(N): for j in range(C): for k in range(H_out): for l in range(W_out): x, y = grid[i][k][l][0], grid[i][k][l][1] param = [0.0, 0.0] param[0] = (W_in - 1) * (x + 1) / 2 param[1] = (H_in - 1) * (y + 1) / 2 x1 = int(param[0] + 1) x0 = x1 - 1 y1 = int(param[1] + 1) y0 = y1 - 1 param[0] = abs(param[0] - x0) param[1] = abs(param[1] - y0) left_top_value, left_bottom_value, right_top_value, right_bottom_value = 0, 0, 0, 0 if 0 <= x0 < W_in and 0 <= y0 < H_in: left_top_value = input[i][j][y0][x0] if 0 <= x1 < W_in and 0 <= y0 < H_in: right_top_value = input[i][j][y0][x1] if 0 <= x0 < W_in and 0 <= y1 < H_in: left_bottom_value = input[i][j][y1][x0] if 0 <= x1 < W_in and 0 <= y1 < H_in: right_bottom_value = input[i][j][y1][x1] left_top = left_top_value * (1 - param[0]) * (1 - param[1]) left_bottom = left_bottom_value * (1 - param[0]) * param[1] right_top = right_top_value * param[0] * (1 - param[1]) right_bottom = right_bottom_value * param[0] * param[1] result = left_bottom + left_top + right_bottom + right_top output[i][j][k][l] = result return outputN, C, H_in, W_in = 1, 1, 4, 4H_out, W_out = 4, 4input = np.random.random((N, C, H_in, W_in))grid = np.random.random((N, H_out, W_out, 2))grid[0][0][0] = [-1.2, 1.3]out = grid_sample(input, grid)print(f'自定义实现输出结果:\n{out}')input = torch.from_numpy(input)grid = torch.from_numpy(grid)output = torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=True)print(f'grid_sample输出结果:\n{output}')

     测试结果:

   

本文链接地址:https://www.jiuchutong.com/zhishi/295209.html 转载请保留说明!

上一篇:成功解决:npm 版本不支持node.js。【 npm v9.1.2 does not support Node.js v16.6.0.】(成功解决冲突的能力英语)

下一篇:Vue 中 forEach() 的使用(vue foreach is not a function)

  • 2021全运会在哪看直播(2021全运会哪里举办)

    2021全运会在哪看直播(2021全运会哪里举办)

  • 荣耀v9是哪一年上市(荣耀 v9)(荣耀v9是哪一年生产的)

    荣耀v9是哪一年上市(荣耀 v9)(荣耀v9是哪一年生产的)

  • qq空间上传视频怎么不被压缩(qq空间上传视频怎么保持原画质)

    qq空间上传视频怎么不被压缩(qq空间上传视频怎么保持原画质)

  • 怎样查辅助验证次数(辅助验证进度查询)

    怎样查辅助验证次数(辅助验证进度查询)

  • 虚拟单和实物单的区别(什么是虚拟单,列举4个)

    虚拟单和实物单的区别(什么是虚拟单,列举4个)

  • 快手非法行为是属于几类(快手非法行为是怎么回事)

    快手非法行为是属于几类(快手非法行为是怎么回事)

  • a92s是5g手机吗(a92s是不是5g手机)

    a92s是5g手机吗(a92s是不是5g手机)

  • im服务器是什么意思(imap服务器是什么意思)

    im服务器是什么意思(imap服务器是什么意思)

  • 手机mtp是什么意思(手机的mtp在哪里设置)

    手机mtp是什么意思(手机的mtp在哪里设置)

  • 抖音拉黑了对方还能发信息吗(抖音拉黑了对方还能看到我直播吗?)

    抖音拉黑了对方还能发信息吗(抖音拉黑了对方还能看到我直播吗?)

  • 华为手机总有提示音怎么回事(华为手机总有提示音 而且还看不到)

    华为手机总有提示音怎么回事(华为手机总有提示音 而且还看不到)

  • 手机音乐怎么传到mp3(手机音乐怎么传到OTG)

    手机音乐怎么传到mp3(手机音乐怎么传到OTG)

  • iphone11发售时间(iphone15发售日期)

    iphone11发售时间(iphone15发售日期)

  • 怎么把照片缩小到1m以下(怎么把照片缩小到50k)

    怎么把照片缩小到1m以下(怎么把照片缩小到50k)

  • 美团怎么处理到店无房(美团订单怎么处理)

    美团怎么处理到店无房(美团订单怎么处理)

  • 1660ti配什么显示器(1660ti要配什么cpu)

    1660ti配什么显示器(1660ti要配什么cpu)

  • 文件视频怎么保存到相册(文件视频怎么保存到电脑)

    文件视频怎么保存到相册(文件视频怎么保存到电脑)

  • xhci模式打开还是关闭(xhci模式打开还是关闭 win7)

    xhci模式打开还是关闭(xhci模式打开还是关闭 win7)

  • qq陌生人管理在哪(新版qq陌生人管理在哪)

    qq陌生人管理在哪(新版qq陌生人管理在哪)

  • 小米9有人脸解锁吗(小米9pro人脸解锁)

    小米9有人脸解锁吗(小米9pro人脸解锁)

  • AMI主板清除CMOS恢复出厂BIOS设置方法图文教程(主板清除bios)

    AMI主板清除CMOS恢复出厂BIOS设置方法图文教程(主板清除bios)

  • icwtutor.exe是什么进程 有什么作用 icwtutor进程查询(.ico是什么文件)

    icwtutor.exe是什么进程 有什么作用 icwtutor进程查询(.ico是什么文件)

  • Http请求-hutool工具类的使用

    Http请求-hutool工具类的使用

  • seata注册nacos报错:nettyServer init error:ErrCode:400, ErrMsg:failed to req API:/api//nacos/v1/ns/instan(seata+nacos)

    seata注册nacos报错:nettyServer init error:ErrCode:400, ErrMsg:failed to req API:/api//nacos/v1/ns/instan(seata+nacos)

  • phpcms是框架吗(phpcms是什么框架)

    phpcms是框架吗(phpcms是什么框架)

  • 普通增值税税率多少
  • 药酒消费税计税依据
  • 预缴所得税报表填错啦,年报可以修改吗
  • 个体工商户开普票限额最新规定
  • 2019发票认证期限新规
  • 网银发工资怎么增员的
  • 出口未开票怎么会计处理
  • 向银行贷款买车 绿本要给银行吗
  • 预提利息属于费用吗
  • 企业间贴现手续怎么办理
  • 造价服务费收费标准计算器
  • 个人所得税违规怎么处理
  • 小规模纳税人可以收13%的专票吗?
  • 上海房产税如何退税
  • 所得税清算时坏账怎么算
  • 工程开发票备注栏必需要写吗?
  • 进项票一定要专票吗
  • 汇算清缴涉及到哪些科目的调整
  • 公司贷款评估费的做账
  • 购买不需安装的生产设备会计分录
  • 企业发生的哪些业务可以使用简易计税法
  • 小米电视连不上路由器怎么回事
  • 公司向法人借款有税务风险吗
  • 对公账户转库存现金对方科目怎么填
  • 银行的贷款怎么发放
  • 农产品专票可以开零税率吗
  • php最好的教程
  • php foo
  • 包装物交不交消费税
  • 预付款不退如何投诉
  • 跨年度坏账准备转回账务处理
  • 辅料分配方法
  • 增值税税率调整为13%的文件
  • avoid什么用法
  • php字符串比较大小
  • 带息应收票据会计处理
  • 劳务外包会计分录最新
  • 冲销暂估入账应该填什么凭证
  • 数据可视化分析
  • vs命令参数
  • php代码自动生成
  • 企业所得税申报更正怎么操作
  • python包发布
  • phpcms 用的是什么模板引擎
  • 管理费用属于产品成本么
  • 完税证明能作为抵扣凭证吗
  • 背书是什么含义
  • 减免的附加税要申报吗
  • 棚户区改造国家给政府拨款吗
  • 企业管理理费包括哪些
  • 房地产开发企业会计制度
  • 外聘人员差旅费用无票调增
  • 一个企业只有收入怎么办
  • 政府补贴递延收益的摊销时间
  • 中药饮片盘点损耗率 法律
  • 在建工程预付款授信
  • 企业会计档案由谁保管
  • 公司为员工租房应注意
  • 如何科学设置运动负荷
  • 填写记账凭证内容摘要的三个要素
  • xp系统没有安装好,请重新运行安装程序
  • Win10 Mobile 10572快速配置更新推送 Win10 Mobile 10572升级体验
  • srvc32.exe - srvc32是什么进程
  • win7浏览器怎么升级到最新版
  • linux文件压缩和备份实验
  • unity系统错误
  • jquery怎么禁用按钮
  • bootstrap页头
  • eclipse开发安卓app实例
  • shell while 小于
  • Python 数据清洗
  • python 列表排序 中文
  • Python3.6正式版新特性预览
  • 基于javascript的毕业设计
  • pythontrutle
  • 支付给境外的特许权使用费
  • 广州市地方税务局官网
  • 西安市港务区属于哪个街道办
  • 融资租赁公司购进车辆账务处理
  • 北京地税查询官网
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设