位置: IT常识 - 正文

pytorch 笔记:torch.distributions 概率分布相关(更新中)(pytorch torch)

编辑:rootadmin
pytorch 笔记:torch.distributions 概率分布相关(更新中) 1 包介绍

推荐整理分享pytorch 笔记:torch.distributions 概率分布相关(更新中)(pytorch torch),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch1.5对应的torchvision,pytorch with torch.nograd,pytorch torch,pytorch torch.load,pytorch torchscript,pytorch torch,pytorch torchvision,pytorch torchvision,内容如对您有帮助,希望把文章链接给更多的朋友!

        torch.distributions包包含可参数化的概率分布和采样函数。 这允许构建用于优化的随机计算图和随机梯度估计器。

        不可能通过随机样本直接反向传播。 但是,有两种主要方法可以创建可以反向传播的代理函数。

这些是

评分函数估计量 score function estimato似然比估计量 likelihood ratio estimatorREINFORCE路径导数估计量 pathwise derivative estimator

REINFORCE 通常被视为强化学习中策略梯度方法的基础,

路径导数估计器常见于变分自编码器的重新参数化技巧中。

        虽然评分函数只需要样本 f(x)的值,但路径导数需要导数 f'(x)。、

1.1 REINFORCE

        我们以reinforce 为例:

        当概率密度函数关于其参数可微时,我们只需要 sample() 和 log_prob() 来实现 REINFORCE:

        

        其中θ是参数,α是学习率,r是奖励,是在状态s的时候,根据策略使用动作a的概率

        (这个也就是policy gradient)

强化学习笔记:Policy-based Approach_UQI-LIUWJ的博客-CSDN博客

         在实践中,我们会从网络的输出中采样一个动作,在一个环境中应用这个动作,然后使用 log_prob 构造一个等效的损失函数。

         对于分类策略,实现 REINFORCE 的代码如下:(这只是一个示意代码,跑不起来的)

probs = policy_network(state)#在状态state的时候,各个action的概率m = Categorical(probs)#分类概率action = m.sample()#采样一个actionnext_state, reward = env.step(action)#这里为了简化考虑,一个episode只有一个actionloss = -m.log_prob(action) * reward#m.log_prob(action) 就是 logp#reward就是前面的r#这里用负号是因为强化学习是梯度上升loss.backward()  2 包所涉及的类2.1 伯努利分布

torch.distributions.bernoulli.Bernoulli( probs=None, logits=None, validate_args=None)

        创建由 probs 或 logits(但不是两者同时)参数化的伯努利分布。

        样本是二进制的(0 或 1)。 它们取值 1 的概率为 p,取值 0 的概率为 1 - p。

2.1.1 参数probs (Number,Tensor) 采样概率logits (Number,Tensor) 采样的对数几率2.1.2 函数 & 属性sample()

采样,默认采样一个值

还可以按照shape 采样

entropy()

计算熵

enumerate_support()

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

mean

均值

probs, logits两个输入的参数param_shape

参数的形状

variance

方差

2.2 贝塔分布torch.distributions.beta.Beta( concentration1, concentration0, validate_args=None)

由concentration 1 (α)和concentration 0 (β)参数化的 Beta 分布。

 2.2.1 函数采样

默认是采样一个值,也可以设置采样的维数

entropy

计算熵

rsample(sample_shape)pytorch 笔记:torch.distributions 概率分布相关(更新中)(pytorch torch)

如果分布参数是批处理的,则生成一个 sample_shape 形状的重新参数化样本或 sample_shape 形状的重新参数化样本批次。

注:生成Beta分布的时候,两个参数必须至少有一个是Tensor,否则rsample效果失效

mean,variance

均值 & 方差

 2.3 Chi2 分布torch.distributions.chi2.Chi2( df, validate_args=None)

 它只有sample一个函数 

2.4 连续伯努利

参数和伯努利很类似

torch.distributions.continuous_bernoulli.ContinuousBernoulli( probs=None, logits=None, lims=(0.499, 0.501), validate_args=None)

请注意,与伯努利不同,这里的“probs”不对应于伯努利的“probs”,这里的“logits”不对应于伯努利的“logits”,但由于与伯努利的相似性,使用了相同的名称。 

2.4.1 函数sample还是采样cdf

返回以 value 计算的累积概率密度函数。

icdf

返回以 value 计算的逆累积密度/质量函数。

entropy

还是计算熵

rsample

如果分布参数是批处理的,则生成一个 sample_shape 形状的重新参数化样本或 sample_shape 形状的重新参数化样本批次。

和前面Beta分布类似,只有创建时参数为Tensor,才会有rsample效果

mean,variance均值 方差 2.5 二项分布torch.distributions.binomial.Binomial( total_count=1, probs=None, logits=None, validate_args=None)

 

         创建由 total_count 和 probs 或 logits(但不是两者)参数化的二项分布。 total_count 必须可以用 probs/logits 广播。

2.5.1 函数&参数sample

采样

 

100被广播到0,0.2,0.8,1 所以每次相当于是四个二项分布

enumerate_support

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

mean,variance

均值,方差

2.6  分类分布torch.distributions.categorical.Categorical( probs=None, logits=None, validate_args=None)

 样本是来{0,...,K−1} 的整数,其中 K 是 probs.size(-1)。

2.6.1 函数sample采样entropy

enumerate_support

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

2.6.2 注意:

创建分类分布时候的Tensor中元素的和可以不是1,最后归一化到1即可

import torchimport mathm=torch.distributions.Categorical(torch.Tensor([1,2,4]))m.enumerate_support()#tensor([0, 1, 2])m.probs#tensor([0.1429, 0.2857, 0.5714])3 log_probs

很多分类都有这样一个函数log_probs,我们就统一说一下

假设m是一个torch的分类,那么m.log_prob(action)相当于

probs.log()[0][action.item()].unsqueeze(0)

(对这个action的概率添加log操作) 

import torchimport mathm=torch.distributions.Categorical(torch.Tensor([1,2,4]))m.enumerate_support()#tensor([0, 1, 2])a=m.sample()a#tensor(2)m.probs#tensor([0.1429, 0.2857, 0.5714])m.probs.log()#tensor([-1.9459, -1.2528, -0.5596])m.log_prob(a)#tensor(-0.5596)m.probs.log()[a.item()]#tensor(-0.5596)
本文链接地址:https://www.jiuchutong.com/zhishi/297674.html 转载请保留说明!

上一篇:Vue的安装及使用教程【超详细图文教程】(vue的安装步骤)

下一篇:Segment Anything Model (SAM)——卷起来了,那个号称分割一切的CV大模型他来了(segment anything model模型 需要的配置)

  • 百词斩打卡截图(百词斩单词包)(百词斩打卡截图生成器)

    百词斩打卡截图(百词斩单词包)(百词斩打卡截图生成器)

  • 苹果13pro怎么关闭智能调整图像智能功能(苹果13pro怎么关闭自动更新)

    苹果13pro怎么关闭智能调整图像智能功能(苹果13pro怎么关闭自动更新)

  • 小爱助理通话怎么关闭(小爱助理通话怎么关闭 note8)

    小爱助理通话怎么关闭(小爱助理通话怎么关闭 note8)

  • 拍照怎么投屏到电视(怎样用手机投屏拍照)

    拍照怎么投屏到电视(怎样用手机投屏拍照)

  • 快手怎么没有创建付费内容(快手怎么没有创作者权益)

    快手怎么没有创建付费内容(快手怎么没有创作者权益)

  • 腾讯课堂后台播放老师知道吗

    腾讯课堂后台播放老师知道吗

  • 闲聊可以提现怎么到不了帐(闲聊能不能提现)

    闲聊可以提现怎么到不了帐(闲聊能不能提现)

  • ipad屏幕白斑会变大吗(ipad屏幕白色斑点)

    ipad屏幕白斑会变大吗(ipad屏幕白色斑点)

  • 宽带闪红灯怎么解决(宽带闪红灯怎么回事儿)

    宽带闪红灯怎么解决(宽带闪红灯怎么回事儿)

  • 苹果微信步数更新滞后(苹果微信步数更新)

    苹果微信步数更新滞后(苹果微信步数更新)

  • 无线适配器或访问点有问题未修复怎么解决(无线适配器或访问点有问题是什么意思)

    无线适配器或访问点有问题未修复怎么解决(无线适配器或访问点有问题是什么意思)

  • 打印机打印不清晰是什么原因(打印机打印不清晰怎么解决)

    打印机打印不清晰是什么原因(打印机打印不清晰怎么解决)

  • 光纤灯一直闪蓝色(光纤灯一直闪蓝光)

    光纤灯一直闪蓝色(光纤灯一直闪蓝光)

  • 垂直同步什么意思(垂直同步有什么用)

    垂直同步什么意思(垂直同步有什么用)

  • 大数据的意义包括(大数据的意义包括什么)

    大数据的意义包括(大数据的意义包括什么)

  • 探探解除匹配还能看到活跃时间吗(探探解除匹配还能看到朋友圈)

    探探解除匹配还能看到活跃时间吗(探探解除匹配还能看到朋友圈)

  • 手机怎么交座机话费(怎么在手机上给座机交话费)

    手机怎么交座机话费(怎么在手机上给座机交话费)

  • vue可以添加多张照片(vue怎么实现多页面)

    vue可以添加多张照片(vue怎么实现多页面)

  • 滴滴怎么绑定学生证(如何绑定滴滴打车平台)

    滴滴怎么绑定学生证(如何绑定滴滴打车平台)

  • 设置了拼多多极速发货如何取消(拼多多设置了极速发货怎么取消)

    设置了拼多多极速发货如何取消(拼多多设置了极速发货怎么取消)

  • 怎么给自己的照片加水印(怎么给自己的照片换衣服)

    怎么给自己的照片加水印(怎么给自己的照片换衣服)

  • 路由器的基本功能(路由器的基本功能是找到一条最佳的数据包传输路线)

    路由器的基本功能(路由器的基本功能是找到一条最佳的数据包传输路线)

  • 电脑系统出问题时如何减少损失?(电脑系统出问题了怎么办)

    电脑系统出问题时如何减少损失?(电脑系统出问题了怎么办)

  • 最新接口的固态硬盘是什么(固态硬盘最新接口)

    最新接口的固态硬盘是什么(固态硬盘最新接口)

  • 在建工程转固定资产需要交税吗
  • 滴滴电子普通发票怎么抵扣
  • 工商年报最迟什么时候申报
  • 民间非营利性组织收到个税手续费返还
  • 自建固定资产入账
  • 刷单成本计入什么费用?
  • 报关单不存在
  • 物业公司收入需要公示
  • 销售固定资产未收到钱
  • 处置交易性金融资产发生的交易费用
  • 汇算清缴所得税退回做账
  • 房产税和土地使用税计入什么科目
  • 国家税务金税四期
  • 关于土地增值税若干问题的通知
  • 制冷设备增值税税率
  • 进项税额加计抵减如何申报
  • 销售不动产预收款纳税义务发生时间
  • 门店关闭费用怎么处理
  • 固定资产入账包括税额吗
  • win11字体大小怎么调
  • word更改单页背景颜色
  • php require的用法
  • PHP:pg_select()的用法_PostgreSQL函数
  • 企业发生的诉讼费用
  • 企业规模扩大后更易于管理吗
  • 销售发票重复开,库存商品怎么处理?
  • phpajax技术
  • 土方工程公司账务实例
  • 金银首饰零售业税负率是多少
  • php 生成缩略图
  • php获取目录所有文件
  • php使用函数限制字符串长度和格式
  • yolov5目标检测流程图
  • groupinfo命令
  • 无形资产的转让
  • phpweb缓存技术
  • 购车哪些费用可以免
  • 办理外经证缴税流程
  • 认证发票可以分两次进行吗
  • 租车发票可以抵扣吗
  • Testing Applications with WebTest¶
  • dedecms使用教程
  • 总承包可以转包吗
  • 公允价值变动损益借贷方向增减
  • 工程产值是怎么算的
  • 综合单价税率调整如何结算
  • 金融资产经营资产
  • 企业给员工的福利体检报告
  • 预付款项属于什么会计要素
  • 母公司和子公司是两个完全独立的法人
  • 公司办理网银
  • 支付宝如何打印付款凭证
  • 企业预付账款的分录
  • 几个常见的收敛级数
  • 通过备份记录获取文件
  • sql server 数据
  • windows隐藏文件夹开启
  • linux sshd是什么
  • bios中英文对照表图新版
  • 怎么恢复Windows xp蓝天白云壁纸
  • 如何远程登录路由器
  • windows8开机蓝屏
  • centos fedora
  • Windows Server 2012服务器管理器的详细介绍
  • win10系统admin和oobe
  • win7累计更新补丁包
  • 在vs中搭建opengl环境
  • python的日志
  • cocos2d怎么用
  • 原生js制作日历软件
  • JavaScript instanceof 的使用方法示例介绍
  • linux 进程监控
  • unity常用代码
  • 税务备案超过15天
  • 广州税务举报电话
  • 银饰品交消费税吗
  • 个人外汇收入申报
  • 陕西省国税务局大企业处长邓谷祥简历
  • 税务局文化建设实施方案
  • 新加坡国税局的电话号码
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设