位置: IT常识 - 正文

【Pytorch深度学习实战】(11)变分自动编码器(VAE)

编辑:rootadmin
【Pytorch深度学习实战】(11)变分自动编码器(VAE)

推荐整理分享【Pytorch深度学习实战】(11)变分自动编码器(VAE),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

 🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎

📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃

🎁欢迎各位→点赞👍 + 收藏⭐️ + 留言📝​

📣系列专栏 - 机器学习【ML】 自然语言处理【NLP】  深度学习【DL】

 🖍foreword

✔说明⇢本人讲解主要包括Python、机器学习(ML)、深度学习(DL)、自然语言处理(NLP)等内容。

如果你对这个系列感兴趣的话,可以关注订阅哟👋

Variational AutoEncoder(VAE)原理

传统的自编码器模型主要由两部分构成:编码器(encoder)和解码器(decoder)。如下图所示:

在上面的模型中,经过反复训练,我们的输入数据X最终被转化为一个编码向量X’, 其中X’的每个维度表示一些学到的关于数据的特征,而X’在每个维度上的取值代表X在该特征上的表现。随后,解码器网络接收X’的这些值并尝试重构原始输入。

举一个例子来加深大家对自编码器的理解:

【Pytorch深度学习实战】(11)变分自动编码器(VAE)

假设任何人像图片都可以由表情、肤色、性别、发型等几个特征的取值来唯一确定,那么我们将一张人像图片输入自动编码器后将会得到这张图片在表情、肤色等特征上的取值的向量X’,而后解码器将会根据这些特征的取值重构出原始输入的这张人像图片。

在上面的示例中,我们使用单个值来描述输入图像在潜在特征上的表现。但在实际情况中,我们可能更多时候倾向于将每个潜在特征表示为可能值的范围。例如,如果输入蒙娜丽莎的照片,将微笑特征设定为特定的单值(相当于断定蒙娜丽莎笑了或者没笑)显然不如将微笑特征设定为某个取值范围(例如将微笑特征设定为x到y范围内的某个数,这个范围内既有数值可以表示蒙娜丽莎笑了又有数值可以表示蒙娜丽莎没笑)更合适。而变分自编码器便是用“取值的概率分布”代替原先的单值来描述对特征的观察的模型,如下图的右边部分所示,经过变分自编码器的编码,每张图片的微笑特征不再是自编码器中的单值而是一个概率分布。

通过这种方法,我们现在将给定输入的每个潜在特征表示为概率分布。当从潜在状态解码时,我们将从每个潜在状态分布中随机采样,生成一个向量作为解码器模型的输入。

通过上述的编解码过程,我们实质上实施了连续,平滑的潜在空间表示。对于潜在分布的所有采样,我们期望我们的解码器模型能够准确重构输入。因此,在潜在空间中彼此相邻的值应该与非常类似的重构相对应。

以上便是变分自编码器构造所依据的原理,我们再来看一看它的具体结构。

如上图所示,与自动编码器由编码器与解码器两部分构成相似,VAE利用两个神经网络建立两个概率密度分布模型:一个用于原始输入数据的变分推断,生成隐变量的变分概率分布,称为推断网络;另一个根据生成的隐变量变分概率分布,还原生成原始数据的近似概率分布,称为生成网络。

假设原始数据集为

,每个数据样本 xi 都是随机产生的相互独立、连续或离散的分布变量,生成数据集合为

,并且假设该过程产生隐变量Z ,即Z是决定X属性的神秘原因(特征)。其中可观测变量X 是一个高维空间的随机向量,不可观测变量 Z 是一个相对低维空间的随机向量,该生成模型可以分成两个过程:

(1)隐变量 Z 后验分布的近似推断过程:

,即推断网络。

(2)生成变量X' 的条件分布生成过程:

,即生成网络。

尽管VAE 整体结构与自编码器AE 结构类似,但VAE 的作用原理和AE 的作用原理完全不同,VAE 的“编码器”和“解码器” 的输出都是受参数约束变量的概率密度分布,而不是某种特定的编码。

变分自编码器Pytorch的实现import osimport torchimport torch.nn as nnimport torch.nn.functional as Fimport torchvisionfrom torchvision import transformsfrom torchvision.utils import save_image# 设备配置device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 如果不存在则创建目录sample_dir = 'samples'if not os.path.exists(sample_dir): os.makedirs(sample_dir)# 超参数image_size = 784h_dim = 400z_dim = 20num_epochs = 15batch_size = 128learning_rate = 1e-3# MNIST 数据集dataset = torchvision.datasets.MNIST(root='../../data', train=True, transform=transforms.ToTensor(), download=True)# 数据加载器data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)# VAE模型class VAE(nn.Module): def __init__(self, image_size=784, h_dim=400, z_dim=20): super(VAE, self).__init__() self.fc1 = nn.Linear(image_size, h_dim) self.fc2 = nn.Linear(h_dim, z_dim) self.fc3 = nn.Linear(h_dim, z_dim) self.fc4 = nn.Linear(z_dim, h_dim) self.fc5 = nn.Linear(h_dim, image_size) def encode(self, x): h = F.relu(self.fc1(x)) return self.fc2(h), self.fc3(h) def reparameterize(self, mu, log_var): std = torch.exp(log_var/2) eps = torch.randn_like(std) return mu + eps * std def decode(self, z): h = F.relu(self.fc4(z)) return F.sigmoid(self.fc5(h)) def forward(self, x): mu, log_var = self.encode(x) z = self.reparameterize(mu, log_var) x_reconst = self.decode(z) return x_reconst, mu, log_varmodel = VAE().to(device)optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 开始训练for epoch in range(num_epochs): for i, (x, _) in enumerate(data_loader): # 前传 x = x.to(device).view(-1, image_size) x_reconst, mu, log_var = model(x) # 计算重建损失和kl散度 reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False) kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) # 反向传播和优化 loss = reconst_loss + kl_div optimizer.zero_grad() loss.backward() optimizer.step() if (i+1) % 10 == 0: print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item())) with torch.no_grad(): # 保存采样图像 z = torch.randn(batch_size, z_dim).to(device) out = model.decode(z).view(-1, 1, 28, 28) save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1))) # 保存重建的图像 out, _, _ = model(x) x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3) save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))
本文链接地址:https://www.jiuchutong.com/zhishi/297306.html 转载请保留说明!

上一篇:vue-nginx刷新404问题

下一篇:前端vscode必备插件推荐(墙裂推荐)(vscode写前端代码,如何运行)

  • 微信营销怎样定位?(微信营销定位安卓版下载)

    微信营销怎样定位?(微信营销定位安卓版下载)

  • 惠普键盘灯光怎么开关在哪(惠普键盘灯光怎么切换模式笔记本)

    惠普键盘灯光怎么开关在哪(惠普键盘灯光怎么切换模式笔记本)

  • 红米k40nfc功能怎么用(红米k40nfc功能怎么打开)

    红米k40nfc功能怎么用(红米k40nfc功能怎么打开)

  • vivo NEX 3s的屏幕多大(vivonex3s的屏幕分辨率是多少)

    vivo NEX 3s的屏幕多大(vivonex3s的屏幕分辨率是多少)

  • 苹果11镜头膜怎么撕下来(苹果11镜头膜怎么贴)

    苹果11镜头膜怎么撕下来(苹果11镜头膜怎么贴)

  • p40pro可以无线充电吗(p40 pro无线充)

    p40pro可以无线充电吗(p40 pro无线充)

  • 微信解除情侣空间对方知道吗(微信解除情侣空间对方有提示吗)

    微信解除情侣空间对方知道吗(微信解除情侣空间对方有提示吗)

  • iphone震动声音滋滋声(iphone震动声音滋滋声修一下多少钱)

    iphone震动声音滋滋声(iphone震动声音滋滋声修一下多少钱)

  • 视频通话怎么录视频

    视频通话怎么录视频

  • 剪映怎么放慢视频其中一部分(剪映怎么放慢视频其中一部分音乐不变)

    剪映怎么放慢视频其中一部分(剪映怎么放慢视频其中一部分音乐不变)

  • iphone如何初始化(如何把苹果手机初始化)

    iphone如何初始化(如何把苹果手机初始化)

  • 手机卡剩话费不用了怎么办(手机话费用不完)

    手机卡剩话费不用了怎么办(手机话费用不完)

  • 机器人联网怎么连不上(机器人连网络怎么用在手机上怎么用)

    机器人联网怎么连不上(机器人连网络怎么用在手机上怎么用)

  • 淘宝退货菜鸟裹裹上门取件流程(淘宝退货菜鸟裹裹能看到是什么物品吗)

    淘宝退货菜鸟裹裹上门取件流程(淘宝退货菜鸟裹裹能看到是什么物品吗)

  • ppt点了不保存怎么恢复(ppt点了不保存怎么找回来)

    ppt点了不保存怎么恢复(ppt点了不保存怎么找回来)

  • 抖音怎么恢复播放量(抖音直播怎么找回来)

    抖音怎么恢复播放量(抖音直播怎么找回来)

  • pr怎么做开场字幕特效(pr怎么做开场动画)

    pr怎么做开场字幕特效(pr怎么做开场动画)

  • 手机相册在电脑上是哪个文件夹(手机相册在电脑里英文是)

    手机相册在电脑上是哪个文件夹(手机相册在电脑里英文是)

  • 淘宝评价异常算违规吗(淘宝评价异常对号有影响吗)

    淘宝评价异常算违规吗(淘宝评价异常对号有影响吗)

  • amd svm怎么开启(amdsvm模式在哪里)

    amd svm怎么开启(amdsvm模式在哪里)

  • 苹果手机怎么给软件上锁(苹果手机怎么给安卓手机传照片)

    苹果手机怎么给软件上锁(苹果手机怎么给安卓手机传照片)

  • 苹果手机的airplay在哪里设置(苹果手机的airplay在哪里打开)

    苹果手机的airplay在哪里设置(苹果手机的airplay在哪里打开)

  • 优酷VIP连续包月如何取消(优酷vip连续包月怎么取消)

    优酷VIP连续包月如何取消(优酷vip连续包月怎么取消)

  • 退款成功但是钱没到账(退款成功但是钱没到银行卡,反而银行卡余额没有了)

    退款成功但是钱没到账(退款成功但是钱没到银行卡,反而银行卡余额没有了)

  • Win11正开发新功能:可直接通过任务栏调整音量(win10开发工具在哪)

    Win11正开发新功能:可直接通过任务栏调整音量(win10开发工具在哪)

  • 不得税前扣除的贷款利息
  • 个体户开电子税务局流程
  • 不征税收入的三个条件文件依据
  • 分公司可以享受当地优惠吗?
  • 以件数为印花税计税依据的有哪些
  • 未实际收到的投资收益要纳税调整吗
  • 税务机关和自然人属于平等主体吗
  • 车险 专票
  • 非营利组织项目
  • 欠缴税款会给纳税证明吗
  • 营改增对电信业的影响及对策
  • 公司税务注销了还有风险吗
  • 仓储费专用发票可以抵扣吗?
  • 分类所得申报要申报吗
  • 小企业怎么申请建设用地
  • 企业注销时资本公积怎么处理
  • 公司内收取的礼金
  • 少交的增值税如何记账
  • mac小技巧
  • mac系统如何开启任何来源
  • doc文档隐藏
  • 年末资产减年初资产
  • 对公转账先打钱后转账
  • php写接口实现json文件读取
  • linux gcc命令详解
  • php如何实现
  • php字符串函数有哪些
  • minilauncher是什么
  • 卖机械配件平台有哪些
  • 厂房维修费是制造费用还是管理费用
  • 个人所得税累计扣除是什么意思
  • 什么是长期应付票据
  • 睿智目标检测yolov8
  • php7 nginx
  • 取得企业债券利息
  • 采购发票生成的会计凭证
  • 应交税费未交增值税
  • 弃置费用预计负债的摊余成本
  • 存货盘亏原因不明会计分录
  • 13个点的普票可以抵税吗
  • 通过点击一个按键的游戏
  • 前端培训费用大概多少
  • 按工资申报的工龄怎么算
  • 会议费发票报销附件
  • 工资薪金所得适用的税率是
  • sql2005服务无法启动sql安装方法
  • mysql 自动重启
  • 购入产品用作样品怎么做
  • 财政补贴的政策
  • 工资中的扣款怎么做账
  • 押金是否可以抵扣租金
  • 减免增值税如何申报
  • 编制记账凭证的依据
  • mysql必知必会读后感2000字
  • Mysql5.7在Centos6中的安装方法
  • mysql join查询慢
  • win10文字模糊怎么调整
  • 重庆四日游最佳攻略超详细
  • icqlite.exe进程的详细介绍 icqlite进程的查询 作用是什么
  • windows7不显示移动硬盘
  • python生成二维码添加图片
  • node解析前端formdata数据
  • android中的布局分为6种,分别是
  • 人机交互编程
  • 安装perl模块
  • javascript有哪些常用的属性和方法
  • mysql 导出指定表
  • 不错的意思
  • hover在jquery中的用法
  • logcat read failure
  • android camera setParameters failed 类问题分析总结
  • 迭代 python
  • js foreach倒序
  • python字典常用操作以及字典的嵌套
  • jquery的ajax提交form表单的两种方法小结(推荐)
  • 地税局属于哪个部门管
  • 工会经费什么时候交
  • 国家税务总局23号文件
  • 怎样在开票系统中增加新的名称
  • 云南烟草税收是多少
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设