位置: IT常识 - 正文

pytorch 笔记:torch.distributions 概率分布相关(更新中)(pytorch torch)

编辑:rootadmin
pytorch 笔记:torch.distributions 概率分布相关(更新中) 1 包介绍

推荐整理分享pytorch 笔记:torch.distributions 概率分布相关(更新中)(pytorch torch),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch1.5对应的torchvision,pytorch with torch.nograd,pytorch torch,pytorch torch.load,pytorch torchscript,pytorch torch,pytorch torchvision,pytorch torchvision,内容如对您有帮助,希望把文章链接给更多的朋友!

        torch.distributions包包含可参数化的概率分布和采样函数。 这允许构建用于优化的随机计算图和随机梯度估计器。

        不可能通过随机样本直接反向传播。 但是,有两种主要方法可以创建可以反向传播的代理函数。

这些是

评分函数估计量 score function estimato似然比估计量 likelihood ratio estimatorREINFORCE路径导数估计量 pathwise derivative estimator

REINFORCE 通常被视为强化学习中策略梯度方法的基础,

路径导数估计器常见于变分自编码器的重新参数化技巧中。

        虽然评分函数只需要样本 f(x)的值,但路径导数需要导数 f'(x)。、

1.1 REINFORCE

        我们以reinforce 为例:

        当概率密度函数关于其参数可微时,我们只需要 sample() 和 log_prob() 来实现 REINFORCE:

        

        其中θ是参数,α是学习率,r是奖励,是在状态s的时候,根据策略使用动作a的概率

        (这个也就是policy gradient)

强化学习笔记:Policy-based Approach_UQI-LIUWJ的博客-CSDN博客

         在实践中,我们会从网络的输出中采样一个动作,在一个环境中应用这个动作,然后使用 log_prob 构造一个等效的损失函数。

         对于分类策略,实现 REINFORCE 的代码如下:(这只是一个示意代码,跑不起来的)

probs = policy_network(state)#在状态state的时候,各个action的概率m = Categorical(probs)#分类概率action = m.sample()#采样一个actionnext_state, reward = env.step(action)#这里为了简化考虑,一个episode只有一个actionloss = -m.log_prob(action) * reward#m.log_prob(action) 就是 logp#reward就是前面的r#这里用负号是因为强化学习是梯度上升loss.backward()  2 包所涉及的类2.1 伯努利分布

torch.distributions.bernoulli.Bernoulli( probs=None, logits=None, validate_args=None)

        创建由 probs 或 logits(但不是两者同时)参数化的伯努利分布。

        样本是二进制的(0 或 1)。 它们取值 1 的概率为 p,取值 0 的概率为 1 - p。

2.1.1 参数probs (Number,Tensor) 采样概率logits (Number,Tensor) 采样的对数几率2.1.2 函数 & 属性sample()

采样,默认采样一个值

还可以按照shape 采样

entropy()

计算熵

enumerate_support()

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

mean

均值

probs, logits两个输入的参数param_shape

参数的形状

variance

方差

2.2 贝塔分布torch.distributions.beta.Beta( concentration1, concentration0, validate_args=None)

由concentration 1 (α)和concentration 0 (β)参数化的 Beta 分布。

 2.2.1 函数采样

默认是采样一个值,也可以设置采样的维数

entropy

计算熵

rsample(sample_shape)pytorch 笔记:torch.distributions 概率分布相关(更新中)(pytorch torch)

如果分布参数是批处理的,则生成一个 sample_shape 形状的重新参数化样本或 sample_shape 形状的重新参数化样本批次。

注:生成Beta分布的时候,两个参数必须至少有一个是Tensor,否则rsample效果失效

mean,variance

均值 & 方差

 2.3 Chi2 分布torch.distributions.chi2.Chi2( df, validate_args=None)

 它只有sample一个函数 

2.4 连续伯努利

参数和伯努利很类似

torch.distributions.continuous_bernoulli.ContinuousBernoulli( probs=None, logits=None, lims=(0.499, 0.501), validate_args=None)

请注意,与伯努利不同,这里的“probs”不对应于伯努利的“probs”,这里的“logits”不对应于伯努利的“logits”,但由于与伯努利的相似性,使用了相同的名称。 

2.4.1 函数sample还是采样cdf

返回以 value 计算的累积概率密度函数。

icdf

返回以 value 计算的逆累积密度/质量函数。

entropy

还是计算熵

rsample

如果分布参数是批处理的,则生成一个 sample_shape 形状的重新参数化样本或 sample_shape 形状的重新参数化样本批次。

和前面Beta分布类似,只有创建时参数为Tensor,才会有rsample效果

mean,variance均值 方差 2.5 二项分布torch.distributions.binomial.Binomial( total_count=1, probs=None, logits=None, validate_args=None)

 

         创建由 total_count 和 probs 或 logits(但不是两者)参数化的二项分布。 total_count 必须可以用 probs/logits 广播。

2.5.1 函数&参数sample

采样

 

100被广播到0,0.2,0.8,1 所以每次相当于是四个二项分布

enumerate_support

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

mean,variance

均值,方差

2.6  分类分布torch.distributions.categorical.Categorical( probs=None, logits=None, validate_args=None)

 样本是来{0,...,K−1} 的整数,其中 K 是 probs.size(-1)。

2.6.1 函数sample采样entropy

enumerate_support

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

2.6.2 注意:

创建分类分布时候的Tensor中元素的和可以不是1,最后归一化到1即可

import torchimport mathm=torch.distributions.Categorical(torch.Tensor([1,2,4]))m.enumerate_support()#tensor([0, 1, 2])m.probs#tensor([0.1429, 0.2857, 0.5714])3 log_probs

很多分类都有这样一个函数log_probs,我们就统一说一下

假设m是一个torch的分类,那么m.log_prob(action)相当于

probs.log()[0][action.item()].unsqueeze(0)

(对这个action的概率添加log操作) 

import torchimport mathm=torch.distributions.Categorical(torch.Tensor([1,2,4]))m.enumerate_support()#tensor([0, 1, 2])a=m.sample()a#tensor(2)m.probs#tensor([0.1429, 0.2857, 0.5714])m.probs.log()#tensor([-1.9459, -1.2528, -0.5596])m.log_prob(a)#tensor(-0.5596)m.probs.log()[a.item()]#tensor(-0.5596)
本文链接地址:https://www.jiuchutong.com/zhishi/297674.html 转载请保留说明!

上一篇:Vue的安装及使用教程【超详细图文教程】(vue的安装步骤)

下一篇:Segment Anything Model (SAM)——卷起来了,那个号称分割一切的CV大模型他来了(segment anything model模型 需要的配置)

  • 微信怎么查看自己支付密码(微信怎么查看自己撤回的消息)

    微信怎么查看自己支付密码(微信怎么查看自己撤回的消息)

  • 小米音乐通知栏怎么关(小米音乐通知栏怎么开)

    小米音乐通知栏怎么关(小米音乐通知栏怎么开)

  • 苹果xs max怎么下载qq

    苹果xs max怎么下载qq

  • 微信号可以用什么符号(微信号可以用什么注册)

    微信号可以用什么符号(微信号可以用什么注册)

  • oppo手机触屏突然不能用了怎么办(oppo手机触屏突然失灵了什么原因)

    oppo手机触屏突然不能用了怎么办(oppo手机触屏突然失灵了什么原因)

  • qq扫一扫识别文字怎么没有了(qq扫描)

    qq扫一扫识别文字怎么没有了(qq扫描)

  • 抖音三个版本有什么区别(抖音几个版本有何不同)

    抖音三个版本有什么区别(抖音几个版本有何不同)

  • 微信视频通话内容能恢复吗

    微信视频通话内容能恢复吗

  • 微博怎么清空艾特记录(微博怎么清空艾特我的评论)

    微博怎么清空艾特记录(微博怎么清空艾特我的评论)

  • 无线网络图标不见了怎么办(无线网络图标不见了怎么办手机)

    无线网络图标不见了怎么办(无线网络图标不见了怎么办手机)

  • win7睡眠无法唤醒黑屏(w7睡眠无法唤醒)

    win7睡眠无法唤醒黑屏(w7睡眠无法唤醒)

  • 腾讯云服务器可以干嘛(腾讯云服务器可以转移账号吗)

    腾讯云服务器可以干嘛(腾讯云服务器可以转移账号吗)

  • 计算机被称为裸机是因为(在计算机领域中所说的裸机是指)

    计算机被称为裸机是因为(在计算机领域中所说的裸机是指)

  • 为什么qq音乐下载的歌不在本地(为什么qq音乐下载的音乐在文件夹找不到)

    为什么qq音乐下载的歌不在本地(为什么qq音乐下载的音乐在文件夹找不到)

  • 淘宝店铺淘气值在哪里(淘宝店铺淘气值怎么看)

    淘宝店铺淘气值在哪里(淘宝店铺淘气值怎么看)

  • word邮件合并在哪(word邮件合并在姓名后添加1号评委)

    word邮件合并在哪(word邮件合并在姓名后添加1号评委)

  • 苹果无线耳机接挂电话(苹果无线耳机接听电话怎么设置)

    苹果无线耳机接挂电话(苹果无线耳机接听电话怎么设置)

  • 苹果1200万像素清晰吗(苹果1200万像素是什么概念)

    苹果1200万像素清晰吗(苹果1200万像素是什么概念)

  • 苹果x日历不显示节假日(苹果x日历不显示父亲节)

    苹果x日历不显示节假日(苹果x日历不显示父亲节)

  • 小米手机怎么截屏长图(小米手机怎么截长图 滚动)

    小米手机怎么截屏长图(小米手机怎么截长图 滚动)

  • 荣耀手环3怎么关机(荣耀手环3怎么连接手机)

    荣耀手环3怎么关机(荣耀手环3怎么连接手机)

  • vue怎么连续添加字幕(vue绑定多个事件)

    vue怎么连续添加字幕(vue绑定多个事件)

  • 苹果电话怎么设置铃声(苹果电话怎么设置紧急联系人号码)

    苹果电话怎么设置铃声(苹果电话怎么设置紧急联系人号码)

  • cad提取坐标生成表格(cad提取坐标生成表格快捷键)

    cad提取坐标生成表格(cad提取坐标生成表格快捷键)

  • 【HTML5】调查问卷制作简约版(html调查问卷简单代码)

    【HTML5】调查问卷制作简约版(html调查问卷简单代码)

  • 企业所得税弥补亏损明细表怎么看
  • 广东通用机打发票可以抵扣吗
  • 经销商返点方案范文
  • 老板故意拖欠税款怎么办
  • 支出算什么会计科目
  • 对公跨行转账汇款(非柜面)手续费单笔 9折
  • 负数发票报税不让填怎么办
  • 累计所得税前净现金流量计算公式为
  • 总分机构移送固定资产是否缴增值税
  • 记账凭证按其适用的交易和事项分类可以分为
  • 房租怎么抵扣个税计算方法
  • 在建工程明细科目
  • 小型企业缴纳企业所得税
  • 付出的房屋押金可以退吗
  • 软件开发服务费会计分录
  • 分公司是否可以贷款
  • 个人所得税生产经营所得税率表2023
  • 增值税专用发票电子版
  • 非正常工资的个税是多少
  • 详细解读财税[2014]75号文件
  • 采用差额计税开什么发票
  • 企业一次性补助金是多少标准
  • 水电费没有发票怎么报销
  • 股东以个人名义签订租赁合同
  • 土地开发中三通一平
  • 验资报告需要什么材料
  • 物业公司收取的广告费开什么发票
  • 公司之间可以借款吗怎么做账
  • 企业废业实收资产怎么算
  • 银行保证金户利息计算公式
  • 普通发票红冲需要填信息表吗
  • 视同销售货物服务无形资产
  • 企业常用的成本核算方法有哪些
  • swiper.js常用功能
  • vue项目页面跳转
  • php面向对象优点,缺点
  • php 抓取别的网站的内容
  • 智能优化算法主要内容
  • $ajax请求
  • 我的年终总结怎么写
  • 面试宝典下载
  • 个税手续费会计分录
  • opengl 帧率
  • 代扣代缴境外增值税税率是多少
  • 发票上多盖了一个发票章
  • 增值税普通发票几个点
  • python动态参数应用
  • SQL Report Builder 报表里面的常见问题分析
  • 金蝶财务软件库存商品数量金额再那查看
  • 增值税冲红后附加税如何申报
  • 增值税申报表填错不影响税额
  • 增值税发票认证期限最新规定
  • sqlserver2008r2怎么使用
  • 增值税发票已抵扣怎么红冲
  • 小企业会计准则适用范围
  • 企业年金个人和公司缴费比例
  • 事业单位企业所得税汇算清缴怎么做
  • 不用鲁大师
  • vmware虚拟机不能识别iso
  • bios各项参数的意义
  • win7开机过程中黑屏
  • Linux服务器管理的开机界面
  • win10操作中心设置
  • zmweb.exe是什么进程
  • qq突然显示windows登录
  • 电脑xp系统虚拟内存不足怎么解决
  • linux文件解压gz
  • [android那些事] linux 下android源码编译(国内被墙方案)
  • jquery教程chm
  • 有趣的单机游戏
  • python搜索功能
  • jquery的选择器有哪几种类型
  • 安卓开发速成
  • python openfoam
  • 可以抵扣的消费税项目
  • 已经交完费还可以用医保卡报销吗
  • 图像信息采集照片
  • 西乡国税局电话
  • 关于船舶吨税的最新法律规范
  • 现在买新车都需要交什么费用
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设