位置: IT常识 - 正文

Pytorch中的grid_sample算子功能解析(pytorch中的数据类型)

编辑:rootadmin
Pytorch中的grid_sample算子功能解析

推荐整理分享Pytorch中的grid_sample算子功能解析(pytorch中的数据类型),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch中的loss函数,pytorch中的forward函数,pytorch中的数据类型,pytorch中的tensor,pytorch中的view函数,pytorch中的张量,pytorch中的loss函数,pytorch中的tensor,内容如对您有帮助,希望把文章链接给更多的朋友!

         pytorch中的grid_sample是一种特殊的采样算法。

调用接口为:

torch.nn.functional.grid_sample(input,grid,mode='bilinear',padding_mode='zeros',align_corners=None)。

         input参数是输入特征图tensor,也就是特征图,可以是四维或者五维张量,以四维形式为例(N,C,Hin,Win),N可以理解为Batch_size,C可以理解为通道数,Hin和Win也就是特征图高和宽。

         grid包含输出特征图特征图的格网大小以及每个格网对应到输入特征图的采样点位,对应四维input,其张量形式为(N,Hout,Wout,2),其中最后一维大小必须为2,如果输入为五维张量,那么最后一维大小必须为3。为什么最后一维必须为2或者3?因为grid的最后一个维度实际上代表一个坐标(x,y)或者(xy,z),对应到输入特征图的二维或三维特征图的坐标维度,xy取值范围一般为[-1,1],该范围映射到输入特征图的全图。

         mode为选择采样方法,有三种内插算法可选,分别是'bilinear'双线性差值、'nearest'最邻近插值、'bicubic' 双三次插值。

Pytorch中的grid_sample算子功能解析(pytorch中的数据类型)

         padding_mode为填充模式,即当(x,y)取值超过输入特征图采样范围,返回一个特定值,有'zeros' 、 'border' 、 'reflection'三种可选,一般用zero。

         align_corners为bool类型,指设定特征图坐标与特征值对应方式,设定为TRUE时,特征值位于像素中心。

         要理解grid_sample是如何工作的,最好就是进行简单的复现。假设输入shape为(N,C,H,W),grid的shape设定为(N,H,W,2),以双线性差值为例进行处理。首先根据input和grid设定,输出特征图tensor的shape为(N,C,H,W),输出特征图上每一个cell上的值由grid最后一维(x,y)确定。那么如何计算输出tensor上每一个点的值?首先,通过(x,y)找到输入特征图上的采样位置,由于xy取值范围为[-1,1],为了便于计算,先将xy取值范围调整为[0,1]。通过(w-1)*(x+1)/2、(wh-1)*(y+1)/2将xy映射为输入特征图的具体坐标位置。将xy映射到特征图实际坐标后,取该坐标附近四个角点特征值,通过四个特征值坐标与采样点坐标相对关系进行双线性插值,得到采样点的值。

注意:xy映射后的坐标可能是输入特征图上任意位置。假设输出特征图上(2,2)坐标位置上的值采样位置可能为输入特征图上(3,4)位置,xy越小越靠近输入特征图左上角,越大则越靠近右下角。

         基于上面的思路,可以进行一个简单的自定义实现。根据指定shape生成input和grid,使用pytorch中的grid_sample算子生成output。之后取grid中的第一个位置中的xy,根据xy从input中通过双线性插值计算出output第一个位置的值。

import torchimport numpy as npdef grid_sample(input, grid): N, C, H_in, W_in = input.shape N, H_out, W_out, _ = grid.shape output = np.random.random((N,C,H,W)) for i in range(N): for j in range(C): for k in range(H_out): for l in range(W_out): param = [0.0, 0.0] param[0] = (W_in - 1) * (grid[i][k][l][0] + 1) / 2 param[1] = (H_in - 1) * (grid[i][k][l][1] + 1) / 2 x0 = int(param[0]) x1 = x0 + 1 y0 = int(param[1]) y1 = y0 + 1 param[0] -= x0 param[1] -= y0 left_top = input[i][j][y0][x0] * (1 - param[0]) * (1 - param[1]) left_bottom = input[i][j][y1][x0] * (1 - param[0]) * param[1] right_top = input[i][j][y0][x1] * param[0] * (1 - param[1]) right_bottom = input[i][j][y1][x1] * param[0] * param[1] result = left_bottom + left_top + right_bottom + right_top output[i][j][k][l] = result return outputN, C, H, W = 1, 1, 4, 4input = np.random.random((N,C,H,W))grid = np.random.random((N,H,W,2))out = grid_sample(input, grid)print(f'自定义实现输出结果:\n{out}')input = torch.from_numpy(input)grid = torch.from_numpy(grid)output = torch.nn.functional.grid_sample(input,grid,mode='bilinear', padding_mode='zeros',align_corners=True)print(f'grid_sample输出结果:\n{output}')

运行结果:

         从输出结果上看,与pytorch基本一致,由于仅仅做简单验证,这里没有对超出[-1,1]范围的xy值做处理,只能处理四维input,五维input的实现思路与这里基本一致。

        考虑到(x,y)取值范围可能越界,pytorch中的padding_mode设置就是对(x,y)落在输入特征图外边缘情况进行处理,一般设置'zero',也就是对靠近输入特征图范围以外的采样点进行0填充,如果不进行处理显然会造成索引越界。要解决(x,y)越界问题,可以进行如下修改:

import torchimport numpy as npdef grid_sample(input, grid): N, C, H_in, W_in = input.shape N, H_out, W_out, _ = grid.shape output = np.random.random((N, C, H_out, W_out)) for i in range(N): for j in range(C): for k in range(H_out): for l in range(W_out): x, y = grid[i][k][l][0], grid[i][k][l][1] param = [0.0, 0.0] param[0] = (W_in - 1) * (x + 1) / 2 param[1] = (H_in - 1) * (y + 1) / 2 x1 = int(param[0] + 1) x0 = x1 - 1 y1 = int(param[1] + 1) y0 = y1 - 1 param[0] = abs(param[0] - x0) param[1] = abs(param[1] - y0) left_top_value, left_bottom_value, right_top_value, right_bottom_value = 0, 0, 0, 0 if 0 <= x0 < W_in and 0 <= y0 < H_in: left_top_value = input[i][j][y0][x0] if 0 <= x1 < W_in and 0 <= y0 < H_in: right_top_value = input[i][j][y0][x1] if 0 <= x0 < W_in and 0 <= y1 < H_in: left_bottom_value = input[i][j][y1][x0] if 0 <= x1 < W_in and 0 <= y1 < H_in: right_bottom_value = input[i][j][y1][x1] left_top = left_top_value * (1 - param[0]) * (1 - param[1]) left_bottom = left_bottom_value * (1 - param[0]) * param[1] right_top = right_top_value * param[0] * (1 - param[1]) right_bottom = right_bottom_value * param[0] * param[1] result = left_bottom + left_top + right_bottom + right_top output[i][j][k][l] = result return outputN, C, H_in, W_in = 1, 1, 4, 4H_out, W_out = 4, 4input = np.random.random((N, C, H_in, W_in))grid = np.random.random((N, H_out, W_out, 2))grid[0][0][0] = [-1.2, 1.3]out = grid_sample(input, grid)print(f'自定义实现输出结果:\n{out}')input = torch.from_numpy(input)grid = torch.from_numpy(grid)output = torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=True)print(f'grid_sample输出结果:\n{output}')

     测试结果:

   

本文链接地址:https://www.jiuchutong.com/zhishi/295209.html 转载请保留说明!

上一篇:成功解决:npm 版本不支持node.js。【 npm v9.1.2 does not support Node.js v16.6.0.】(成功解决冲突的能力英语)

下一篇:Vue 中 forEach() 的使用(vue foreach is not a function)

  • 华为平板输入法怎么设置(华为平板输入法怎么切换中英文)

    华为平板输入法怎么设置(华为平板输入法怎么切换中英文)

  • 微信红包封面在哪里查看(微信红包封面在哪里兑换)

    微信红包封面在哪里查看(微信红包封面在哪里兑换)

  • 酷狗怎么把音乐下到U盘里(酷狗怎么把音乐下载到u盘)

    酷狗怎么把音乐下到U盘里(酷狗怎么把音乐下载到u盘)

  • 京东上门取件怎么修改地址(京东上门取件怎么下单)

    京东上门取件怎么修改地址(京东上门取件怎么下单)

  • 抖音获得铁粉标志的方法是什么(抖音铁粉标签是怎么样产生的)

    抖音获得铁粉标志的方法是什么(抖音铁粉标签是怎么样产生的)

  • win7怎么升级win10(win7怎么升级win10系统版本2022)

    win7怎么升级win10(win7怎么升级win10系统版本2022)

  • 华为mate20可不可以插内存卡(华为mate20可不可以反向充电)

    华为mate20可不可以插内存卡(华为mate20可不可以反向充电)

  • 华为荣耀20青春版防不防水(华为荣耀20青春版多少钱)

    华为荣耀20青春版防不防水(华为荣耀20青春版多少钱)

  • 十进制0.6875转换为二进制(十进制0.6875转换为二进制怎么算)

    十进制0.6875转换为二进制(十进制0.6875转换为二进制怎么算)

  • 全站仪hr和dhr是什么意思(全站仪hr hd)

    全站仪hr和dhr是什么意思(全站仪hr hd)

  • 更新数据库的查询称为(更新数据库的查询命令)

    更新数据库的查询称为(更新数据库的查询命令)

  • 手机息屏后断网怎么办(为什么手机关屏之后网络自己断了)

    手机息屏后断网怎么办(为什么手机关屏之后网络自己断了)

  • 闪聊为什么停服(闪聊什么时候重新开放)

    闪聊为什么停服(闪聊什么时候重新开放)

  • 小米手机下面的三个键怎么设置不见了(小米手机下面的返回键怎么设置出来)

    小米手机下面的三个键怎么设置不见了(小米手机下面的返回键怎么设置出来)

  • 剪映卡顿怎么办(剪映卡顿怎么办啊)

    剪映卡顿怎么办(剪映卡顿怎么办啊)

  • 申请电脑直播权限要多久(申请电脑直播权限)

    申请电脑直播权限要多久(申请电脑直播权限)

  • 华为lnd一al40是什么型号(华为lnd一al40是什么型号手机)

    华为lnd一al40是什么型号(华为lnd一al40是什么型号手机)

  • nova6 mate30区别(华为nova6和mate30哪个值得入手)

    nova6 mate30区别(华为nova6和mate30哪个值得入手)

  • 抖音会员有什么权利(抖音会员有什么条件)

    抖音会员有什么权利(抖音会员有什么条件)

  • 固态nvme需要开ahci吗(nvme固态和sata固态开机速度对比)

    固态nvme需要开ahci吗(nvme固态和sata固态开机速度对比)

  • 为什么苹果天气显示不出来(为什么苹果天气定位是别的地方)

    为什么苹果天气显示不出来(为什么苹果天气定位是别的地方)

  • 爱奇艺下载的视频在哪里(爱奇艺下载的视频怎么传到u盘)

    爱奇艺下载的视频在哪里(爱奇艺下载的视频怎么传到u盘)

  • ios13暗夜模式省电吗(苹果暗夜模式费电吗)

    ios13暗夜模式省电吗(苹果暗夜模式费电吗)

  • iphone11反向无线充电怎么使用(iphone11反向无线充电)

    iphone11反向无线充电怎么使用(iphone11反向无线充电)

  • airpods刻字能退吗(airpods刻字可以退吗)

    airpods刻字能退吗(airpods刻字可以退吗)

  • 路由器sn码有什么用(无线路由器sn啥意思)

    路由器sn码有什么用(无线路由器sn啥意思)

  • pe怎么装系统(Pe怎么装系统)

    pe怎么装系统(Pe怎么装系统)

  • b站如何取消挂件(b站如何关掉)

    b站如何取消挂件(b站如何关掉)

  • 数据类型转换(数据类型转换分为哪两种)

    数据类型转换(数据类型转换分为哪两种)

  • ps怎么把图片套入模板(ps怎么把图片套入样机快捷键)

    ps怎么把图片套入模板(ps怎么把图片套入样机快捷键)

  • 进货开了发票也写了购销合同要交印花税吗?
  • 广告宣传费扣除比例
  • 存货被盗的会计分录
  • 差额增税可以抵扣吗
  • 资源税折算后计提怎么算
  • 复式记账的优点有哪些
  • 亏损企业捐赠支出怎么算
  • 出售使用过的生产设备
  • 营改增后还有营业费用吗
  • 其他综合收益是利润表项目吗
  • 税务管理相关知识
  • 社保基数与个税缴纳基数一致
  • 企业间借贷利息如何入账
  • 购买固定资产的进口关税
  • 防伪税控开票系统SOAP服务端
  • 港口建设费2021年归国家税务总局
  • 生产性生物资产折旧年限
  • 支付青苗补偿费怎么做账
  • 工程项目管理人员任命书
  • 发票丢失登报声明怎么写
  • 电费返还怎么查询
  • 工会经费的优惠政策2020
  • 不动产进项税为什么不能抵扣
  • 工资扣员工的罚款入什么科目
  • 民间非营利组织会计制度最新版
  • configureandwatch
  • 苹果手机录音怎么转换成mp3格式
  • 电脑下载的文件打不开怎么回事
  • 原材料盘亏计入
  • 代开专票计提附加税吗
  • 航天信息服务费是什么费用
  • 谈谈你对人民美好生活的理解
  • ryzen3 2200g相当于i几
  • sguard是什么
  • 电脑故障检测与维护方法
  • 收到现金股利会引起什么变化
  • mom.exe是什么
  • 如何清理电脑浏览器
  • vue的安装命令
  • 零售业的进货帐务怎么做
  • 库存股属于什么类账户
  • php消息实时推送完整示例
  • 土地出让金返还的税务处理
  • 为什么说网络安全靠人民
  • cmd more命令
  • 股权对价支付
  • 快递明细单
  • hashmap的使用场景
  • 建筑业成本核算流程
  • 固定资产以什么资金形态存在
  • 税金及附加包括哪些科目
  • 坏账准备怎么结转到本年利润
  • 制造费用工资计入什么科目
  • 确认收入未开发票
  • 低值易耗品一次性摊销会计科目
  • 失业保险金的支付方式
  • 教育费附加计入其他应付款吗
  • 购买办公软件的进项发票可以抵扣吗
  • windows下启动mysql的命令是什么
  • mysql中count(), group by, order by使用详解
  • linux git教程
  • centos更新yum update
  • ubuntu20怎么连接蓝牙鼠标
  • linux的run目录放什么文件
  • win8怎么彻底删除软件
  • 个人pc用户免费下载软件
  • win7报错0x0000007b
  • win8初始登录账号密码
  • redhat linux8
  • 使用灭火器人要站在上风口还是下风口
  • Cocos2dx 3.2 + vs2012 + win7 改变面黑色背景的大小
  • javascript中
  • python中文分词代码
  • android开发工程师岗位说明
  • 关于python中的判断条件
  • android实现选择题模式
  • 安徽省电子税务局怎么下载
  • 四川国税网上办税
  • 举报电话12345管用吗
  • 金税盘读取发票
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设