位置: IT常识 - 正文

pytorch 笔记:torch.distributions 概率分布相关(更新中)(pytorch torch)

编辑:rootadmin
pytorch 笔记:torch.distributions 概率分布相关(更新中) 1 包介绍

推荐整理分享pytorch 笔记:torch.distributions 概率分布相关(更新中)(pytorch torch),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch1.5对应的torchvision,pytorch with torch.nograd,pytorch torch,pytorch torch.load,pytorch torchscript,pytorch torch,pytorch torchvision,pytorch torchvision,内容如对您有帮助,希望把文章链接给更多的朋友!

        torch.distributions包包含可参数化的概率分布和采样函数。 这允许构建用于优化的随机计算图和随机梯度估计器。

        不可能通过随机样本直接反向传播。 但是,有两种主要方法可以创建可以反向传播的代理函数。

这些是

评分函数估计量 score function estimato似然比估计量 likelihood ratio estimatorREINFORCE路径导数估计量 pathwise derivative estimator

REINFORCE 通常被视为强化学习中策略梯度方法的基础,

路径导数估计器常见于变分自编码器的重新参数化技巧中。

        虽然评分函数只需要样本 f(x)的值,但路径导数需要导数 f'(x)。、

1.1 REINFORCE

        我们以reinforce 为例:

        当概率密度函数关于其参数可微时,我们只需要 sample() 和 log_prob() 来实现 REINFORCE:

        

        其中θ是参数,α是学习率,r是奖励,是在状态s的时候,根据策略使用动作a的概率

        (这个也就是policy gradient)

强化学习笔记:Policy-based Approach_UQI-LIUWJ的博客-CSDN博客

         在实践中,我们会从网络的输出中采样一个动作,在一个环境中应用这个动作,然后使用 log_prob 构造一个等效的损失函数。

         对于分类策略,实现 REINFORCE 的代码如下:(这只是一个示意代码,跑不起来的)

probs = policy_network(state)#在状态state的时候,各个action的概率m = Categorical(probs)#分类概率action = m.sample()#采样一个actionnext_state, reward = env.step(action)#这里为了简化考虑,一个episode只有一个actionloss = -m.log_prob(action) * reward#m.log_prob(action) 就是 logp#reward就是前面的r#这里用负号是因为强化学习是梯度上升loss.backward()  2 包所涉及的类2.1 伯努利分布

torch.distributions.bernoulli.Bernoulli( probs=None, logits=None, validate_args=None)

        创建由 probs 或 logits(但不是两者同时)参数化的伯努利分布。

        样本是二进制的(0 或 1)。 它们取值 1 的概率为 p,取值 0 的概率为 1 - p。

2.1.1 参数probs (Number,Tensor) 采样概率logits (Number,Tensor) 采样的对数几率2.1.2 函数 & 属性sample()

采样,默认采样一个值

还可以按照shape 采样

entropy()

计算熵

enumerate_support()

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

mean

均值

probs, logits两个输入的参数param_shape

参数的形状

variance

方差

2.2 贝塔分布torch.distributions.beta.Beta( concentration1, concentration0, validate_args=None)

由concentration 1 (α)和concentration 0 (β)参数化的 Beta 分布。

 2.2.1 函数采样

默认是采样一个值,也可以设置采样的维数

entropy

计算熵

rsample(sample_shape)pytorch 笔记:torch.distributions 概率分布相关(更新中)(pytorch torch)

如果分布参数是批处理的,则生成一个 sample_shape 形状的重新参数化样本或 sample_shape 形状的重新参数化样本批次。

注:生成Beta分布的时候,两个参数必须至少有一个是Tensor,否则rsample效果失效

mean,variance

均值 & 方差

 2.3 Chi2 分布torch.distributions.chi2.Chi2( df, validate_args=None)

 它只有sample一个函数 

2.4 连续伯努利

参数和伯努利很类似

torch.distributions.continuous_bernoulli.ContinuousBernoulli( probs=None, logits=None, lims=(0.499, 0.501), validate_args=None)

请注意,与伯努利不同,这里的“probs”不对应于伯努利的“probs”,这里的“logits”不对应于伯努利的“logits”,但由于与伯努利的相似性,使用了相同的名称。 

2.4.1 函数sample还是采样cdf

返回以 value 计算的累积概率密度函数。

icdf

返回以 value 计算的逆累积密度/质量函数。

entropy

还是计算熵

rsample

如果分布参数是批处理的,则生成一个 sample_shape 形状的重新参数化样本或 sample_shape 形状的重新参数化样本批次。

和前面Beta分布类似,只有创建时参数为Tensor,才会有rsample效果

mean,variance均值 方差 2.5 二项分布torch.distributions.binomial.Binomial( total_count=1, probs=None, logits=None, validate_args=None)

 

         创建由 total_count 和 probs 或 logits(但不是两者)参数化的二项分布。 total_count 必须可以用 probs/logits 广播。

2.5.1 函数&参数sample

采样

 

100被广播到0,0.2,0.8,1 所以每次相当于是四个二项分布

enumerate_support

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

mean,variance

均值,方差

2.6  分类分布torch.distributions.categorical.Categorical( probs=None, logits=None, validate_args=None)

 样本是来{0,...,K−1} 的整数,其中 K 是 probs.size(-1)。

2.6.1 函数sample采样entropy

enumerate_support

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

2.6.2 注意:

创建分类分布时候的Tensor中元素的和可以不是1,最后归一化到1即可

import torchimport mathm=torch.distributions.Categorical(torch.Tensor([1,2,4]))m.enumerate_support()#tensor([0, 1, 2])m.probs#tensor([0.1429, 0.2857, 0.5714])3 log_probs

很多分类都有这样一个函数log_probs,我们就统一说一下

假设m是一个torch的分类,那么m.log_prob(action)相当于

probs.log()[0][action.item()].unsqueeze(0)

(对这个action的概率添加log操作) 

import torchimport mathm=torch.distributions.Categorical(torch.Tensor([1,2,4]))m.enumerate_support()#tensor([0, 1, 2])a=m.sample()a#tensor(2)m.probs#tensor([0.1429, 0.2857, 0.5714])m.probs.log()#tensor([-1.9459, -1.2528, -0.5596])m.log_prob(a)#tensor(-0.5596)m.probs.log()[a.item()]#tensor(-0.5596)
本文链接地址:https://www.jiuchutong.com/zhishi/297674.html 转载请保留说明!

上一篇:Vue的安装及使用教程【超详细图文教程】(vue的安装步骤)

下一篇:Segment Anything Model (SAM)——卷起来了,那个号称分割一切的CV大模型他来了(segment anything model模型 需要的配置)

  • 教你写团购推广方案(如何做团购推广文案)

    教你写团购推广方案(如何做团购推广文案)

  • 拼多多怎么删除银行卡(拼多多怎么删除评价)

    拼多多怎么删除银行卡(拼多多怎么删除评价)

  • 联想k3蓝牙音箱如何串联(联想k3蓝牙音箱说明书)

    联想k3蓝牙音箱如何串联(联想k3蓝牙音箱说明书)

  • 拼多多极速退款怎么关闭(拼多多极速退款后卖家拒收退货,必须还款吗)

    拼多多极速退款怎么关闭(拼多多极速退款后卖家拒收退货,必须还款吗)

  • 手机充不进去电是怎么回事(手机充不进去电但是显示在充电)

    手机充不进去电是怎么回事(手机充不进去电但是显示在充电)

  • 微信强提醒怎么设置(微信强提醒怎么没有了)

    微信强提醒怎么设置(微信强提醒怎么没有了)

  • 抖音里能不能看到访客(抖音里能不能看见谁来过)

    抖音里能不能看到访客(抖音里能不能看见谁来过)

  • 笔记本电脑小数点是哪个键(笔记本电脑小数点符号在哪里)

    笔记本电脑小数点是哪个键(笔记本电脑小数点符号在哪里)

  • 如何下载别人的快手视频(如何下载别人的抖音作品)

    如何下载别人的快手视频(如何下载别人的抖音作品)

  • ppt怎么设置艺术字位置(ppt怎么设置艺术效果)

    ppt怎么设置艺术字位置(ppt怎么设置艺术效果)

  • 抖音群里最多可以加入多少人(抖音群里最多可加多少人)

    抖音群里最多可以加入多少人(抖音群里最多可加多少人)

  • 微信如何禁言一个人(微信如何禁言一个人不让他知道)

    微信如何禁言一个人(微信如何禁言一个人不让他知道)

  • 快手怎么设置头像挂件?(快手怎么设置头像的头环)

    快手怎么设置头像挂件?(快手怎么设置头像的头环)

  • 键盘上句号是哪一个键(键盘上的句号是哪个键)

    键盘上句号是哪一个键(键盘上的句号是哪个键)

  • 电脑微信和手机微信能不同时在线吗(电脑微信和手机微信可以不同步吗)

    电脑微信和手机微信能不同时在线吗(电脑微信和手机微信可以不同步吗)

  • 淘宝评价管理已处理评价是什么意思(淘宝评价管理已处理的评价怎么回事)

    淘宝评价管理已处理评价是什么意思(淘宝评价管理已处理的评价怎么回事)

  • 天猫垫付账户在哪(天猫垫付账户是什么意思)

    天猫垫付账户在哪(天猫垫付账户是什么意思)

  • 小米9手机怎么截屏(小米9手机怎么截长图)

    小米9手机怎么截屏(小米9手机怎么截长图)

  • 荣耀20pro支持红外吗(荣耀20pro支持红外功能吗)

    荣耀20pro支持红外吗(荣耀20pro支持红外功能吗)

  • 美团怎么看全年消费(怎么看在美团每年消费总额)

    美团怎么看全年消费(怎么看在美团每年消费总额)

  • 苹果通知栏怎么自定义(苹果通知栏怎么突然打不开了)

    苹果通知栏怎么自定义(苹果通知栏怎么突然打不开了)

  • 魅族新智能冻结3.0s干嘛的(魅族智能冻结3.0怎样关闭)

    魅族新智能冻结3.0s干嘛的(魅族智能冻结3.0怎样关闭)

  • 拼多多怎么关闭自动连抽(拼多多怎么关闭别人看到我买的东西)

    拼多多怎么关闭自动连抽(拼多多怎么关闭别人看到我买的东西)

  • uc浏览器如何进入阅读模式(uc浏览器如何进入网页)

    uc浏览器如何进入阅读模式(uc浏览器如何进入网页)

  • 帝国cms使用js+css实现当前栏目高亮效果的方法(帝国cms使用redis)

    帝国cms使用js+css实现当前栏目高亮效果的方法(帝国cms使用redis)

  • 清税证明是什么要钱吗
  • 收到失业保险稳岗补贴会计分录
  • 代扣代缴增值税怎么做账
  • 汇算清缴交了税怎么做分录
  • 发票能加盖公章吗
  • 委托邮政企业投诉电话
  • 增值税组成计税价格包括消费税吗
  • 每个月0申报,对企业有什么影响吗?
  • 工会签约有什么好处
  • 农产品销售发票可以抵扣吗?
  • 个人互换住房土地增值税
  • 劳务分包预缴税怎么算
  • 顾客退货补差价怎么做账?
  • 已认证发票退回的流程
  • 去年的税还能退吗
  • 代开发票有哪些问题需要知道的?
  • 税收征收管理法
  • 免征增值税的进项税额如何处理
  • 工会经费的开支必须取得发票么
  • 开户许可证复印件是什么
  • win11系统关闭防火墙怎么关
  • 安装win7提示版本过低
  • PHP:diskfreespace()的用法_Filesystem函数
  • array php
  • 解决脱发的8个方法
  • 冲减应付账款如何做账
  • 怎么给复选框赋值
  • 无数据库cms
  • 母公司与子公司交易属于关联交易吗
  • 售后租回交易的资产销售价低于市场价承租人作为
  • 限定性净资产是指什么
  • 补交上年度的企业所得税税款计入什么科目
  • 怎么查看python
  • 桥接模式例子
  • 职工教育经费支出比例
  • pytest conftest
  • CentOS 6.5 x64系统中安装MongoDB 2.6.0二进制发行版教程
  • mongodb聚合函数详解
  • 兼职工资比正式工的工资高还是低
  • 固定资产折旧计入什么费用
  • 固定资产核销是资产损失吗
  • 应收账款逾期什么意思
  • 直接人工成本项目
  • 认证抵扣发票
  • 城市生活垃圾处理方式有哪几种
  • 过节费可以发现金吗
  • 收到进项发票不认证抵扣的会计分录
  • 分公司从业人数填越少越好吗
  • 出售二手固定资产如何开票
  • sql server获取字段长度
  • window怎么操作
  • 取消windows开机登录密码
  • win7 64位系统RAR压缩文件损坏该怎么修复
  • win8.1应用
  • macbook的dock栏怎么不见了
  • centos8设置默认启动命令界面
  • linux流量控制
  • 在linux系统中
  • win8使用教程和技能
  • unity固定位置随机生成物体
  • 微信小程序实现人脸识别
  • linux中makefile怎么写
  • node.js deno
  • NGUI之UIRoot
  • express框架的优缺点
  • python urljoin
  • 玩端游的平台
  • 超链接打开比较合适的方式是什么
  • android开发模式
  • 怎么运用知识点
  • Python中的多行注释文档编写风格汇总
  • 在linux安装python
  • 企业季度申报怎么报
  • 211学生占全国比例
  • 公司欠税款,还不上,怎么办
  • 衡水地税局税务电话
  • 汽车发票含税吗
  • 网上缴费如何开票
  • 税务网上抄报流程是什么
  • 西藏自治区税务局电子税务局
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设