位置: IT常识 - 正文

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

编辑:rootadmin
GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点

推荐整理分享GANs系列:CGAN(条件GAN)原理简介以及项目代码实现,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

       生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信息y 融入生成器G和判别器D中,属性y可以是任何标签信息, 例如图像的类别、人脸图像的面部表情等。

二、CGAN的基本原理

      cGAN的中心思想是希望 可以控制 GAN 生成的图片,而不 是单纯的随机生成图片。 具体来说,Conditional GAN 在生成器和判别器的输入中 增加了额外的 条件信息,生成器生成的图片只有足够真实 且与条件相符,才能够通过判别器。

      实际上 , 在无条件约束的生成模型中 , 没法控制数据生成的模式。然而,通过额外的信息对模型进行约束,有可能指导数据生成的过程。条件约束可以是类标签 , 可以是图像修补的部分数据, 甚至是来自不同模态的数据

cGAN将 无监督学习 转为 有监督学习 使得网络可以更好地在我们的掌控下进行学习!

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

从公式看,cgan相当于在原始GAN的基础上对生成器部分 和判别器部分都加了一个条件

三、CGAN模型

如果将上图绿色部分的y去掉,就是GAN的原理图。 

 四、CGAN结构

为了实现条件GAN的目的,生成网络和判别网络的原理和 训练方式均要有所改变。

模型部分,在判别器和生成器中都添加了额外信息 y,y 可 以是类别标签或者是其他类型的数据,可以将 y 作为一个 额外的输入层丢入判别器和生成器。 

在生成器中,作者将输入噪声 z 和 y 连在一起隐含表示, 带条件约束这个简单直接的改进被证明非常有效,并广泛用 于后续的相关工作中。论文是在MNIST数据集上以类别标 签为条件变量,生成指定类别的图像。作者还探索了CGAN 在用于图像自动标注的多模态学习上的应用,在MIR Flickr25000数据集上,以图像特征为条件变量,生成该图像的tag的词向量。

 五、CGAN缺陷

cGAN生成的图像虽有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低等,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像风格时对属性特征的 处理方法均受cGAN启发。

六、代码实现,生成指定手写数字import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimimport numpy as npimport matplotlib.pyplot as pltimport torchvisionfrom torchvision import transformsfrom torch.utils import dataimport osimport globfrom PIL import Image# 独热编码# 输入x代表默认的torchvision返回的类比值,class_count类别值为10def one_hot(x, class_count=10): return torch.eye(class_count)[x, :] # 切片选取,第一维选取第x个,第二维全要transform =transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])dataset = torchvision.datasets.MNIST('data', train=True, transform=transform, target_transform=one_hot, download=False)dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)# 定义生成器class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.linear1 = nn.Linear(10, 128 * 7 * 7) self.bn1 = nn.BatchNorm1d(128 * 7 * 7) self.linear2 = nn.Linear(100, 128 * 7 * 7) self.bn2 = nn.BatchNorm1d(128 * 7 * 7) self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), padding=1) self.bn3 = nn.BatchNorm2d(128) self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=2, padding=1) self.bn4 = nn.BatchNorm2d(64) self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=2, padding=1) def forward(self, x1, x2): x1 = F.relu(self.linear1(x1)) x1 = self.bn1(x1) x1 = x1.view(-1, 128, 7, 7) x2 = F.relu(self.linear2(x2)) x2 = self.bn2(x2) x2 = x2.view(-1, 128, 7, 7) x = torch.cat([x1, x2], axis=1) x = F.relu(self.deconv1(x)) x = self.bn3(x) x = F.relu(self.deconv2(x)) x = self.bn4(x) x = torch.tanh(self.deconv3(x)) return x# 定义判别器# input:1,28,28的图片以及长度为10的conditionclass Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.linear = nn.Linear(10, 1*28*28) self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2) self.bn = nn.BatchNorm2d(128) self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值 def forward(self, x1, x2): x1 =F.leaky_relu(self.linear(x1)) x1 = x1.view(-1, 1, 28, 28) x = torch.cat([x1, x2], axis=1) x = F.dropout2d(F.leaky_relu(self.conv1(x))) x = F.dropout2d(F.leaky_relu(self.conv2(x))) x = self.bn(x) x = x.view(-1, 128*6*6) x = torch.sigmoid(self.fc(x)) return x# 初始化模型device = 'cuda' if torch.cuda.is_available() else 'cpu'gen = Generator().to(device)dis = Discriminator().to(device)# 损失计算函数loss_function = torch.nn.BCELoss()# 定义优化器d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)# 定义可视化函数def generate_and_save_images(model, epoch, label_input, noise_input): predictions = np.squeeze(model(label_input, noise_input).cpu().numpy()) fig = plt.figure(figsize=(4, 4)) for i in range(predictions.shape[0]): plt.subplot(4, 4, i + 1) plt.imshow((predictions[i] + 1) / 2, cmap='gray') plt.axis("off") plt.savefig('D:/practice/CGAN/img/image_at_epoch_{:04d}.png'.format(epoch)) plt.show()noise_seed = torch.randn(16, 100, device=device)label_seed = torch.randint(0, 10, size=(16,))label_seed_onehot = one_hot(label_seed).to(device)print(label_seed)# print(label_seed_onehot)# 开始训练D_loss = []G_loss = []# 训练循环for epoch in range(150): d_epoch_loss = 0 g_epoch_loss = 0 count = len(dataloader.dataset) # 对全部的数据集做一次迭代 for step, (img, label) in enumerate(dataloader): img = img.to(device) label = label.to(device) size = img.shape[0] random_noise = torch.randn(size, 100, device=device) d_optim.zero_grad() real_output = dis(label, img) d_real_loss = loss_function(real_output, torch.ones_like(real_output, device=device) ) d_real_loss.backward() #求解梯度 # 得到判别器在生成图像上的损失 gen_img = gen(label,random_noise) fake_output = dis(label, gen_img.detach()) # 判别器输入生成的图片,f_o是对生成图片的预测结果 d_fake_loss = loss_function(fake_output, torch.zeros_like(fake_output, device=device)) d_fake_loss.backward() d_loss = d_real_loss + d_fake_loss d_optim.step() # 优化 # 得到生成器的损失 g_optim.zero_grad() fake_output = dis(label, gen_img) g_loss = loss_function(fake_output, torch.ones_like(fake_output, device=device)) g_loss.backward() g_optim.step() with torch.no_grad(): d_epoch_loss += d_loss.item() g_epoch_loss += g_loss.item() with torch.no_grad(): d_epoch_loss /= count g_epoch_loss /= count D_loss.append(d_epoch_loss) G_loss.append(g_epoch_loss) if epoch % 10 == 0: print('Epoch:', epoch) generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)plt.plot(D_loss, label='D_loss')plt.plot(G_loss, label='G_loss')plt.legend()plt.show()

具体实战代码解读,参考:GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字

本文链接地址:https://www.jiuchutong.com/zhishi/300285.html 转载请保留说明!

上一篇:R-CNN史上最全讲解(rcnn系列详解)

下一篇:结构重参数化(Structural Re-Parameters)PipLine(结构重参数化2d pose)

  • mysql 查看版本(mysql如何查看版本)

    mysql 查看版本(mysql如何查看版本)

  • 米兔手表如何恢复出厂设置(米兔手表如何恢复)

    米兔手表如何恢复出厂设置(米兔手表如何恢复)

  • users是什么文件夹能删除吗(电脑users是什么文件)

    users是什么文件夹能删除吗(电脑users是什么文件)

  • 华为荣耀10手机音量小怎么办(华为荣耀10手机壳)

    华为荣耀10手机音量小怎么办(华为荣耀10手机壳)

  • 全民k歌耳机滋滋滋的响(全民k歌耳机嗡嗡响)

    全民k歌耳机滋滋滋的响(全民k歌耳机嗡嗡响)

  • 电脑打字的顿号在上面怎样打(电脑打字顿号怎么输入都是斜杠)

    电脑打字的顿号在上面怎样打(电脑打字顿号怎么输入都是斜杠)

  • 电脑修改照片格式为jpg怎么弄(电脑修改照片格式的软件)

    电脑修改照片格式为jpg怎么弄(电脑修改照片格式的软件)

  • 苹果双摄像头是哪款(苹果双摄像头是苹果几)

    苹果双摄像头是哪款(苹果双摄像头是苹果几)

  • 小米10pro指纹解锁不灵敏(小米10pro指纹解锁屏幕闪烁一下)

    小米10pro指纹解锁不灵敏(小米10pro指纹解锁屏幕闪烁一下)

  • 电脑键盘add是哪个键(键盘ad键不好用怎么办)

    电脑键盘add是哪个键(键盘ad键不好用怎么办)

  • mac笔记本是什么牌子(macbook笔记本电脑)

    mac笔记本是什么牌子(macbook笔记本电脑)

  • iphone原装电池是什么牌子的(iphone原装电池长啥样)

    iphone原装电池是什么牌子的(iphone原装电池长啥样)

  • p30锁屏时间位置(华为p30pro锁屏时间位置移动)

    p30锁屏时间位置(华为p30pro锁屏时间位置移动)

  • dualbios是什么意思(电脑出现dualbios是什么意思)

    dualbios是什么意思(电脑出现dualbios是什么意思)

  • iphone7支持pd快充吗(iphone7支持快充吗能用2.1a充电吗)

    iphone7支持pd快充吗(iphone7支持快充吗能用2.1a充电吗)

  • 华为手机信号出现沙漏(华为手机信号出现地球)

    华为手机信号出现沙漏(华为手机信号出现地球)

  • 华为抬起唤醒怎么设置(华为抬起唤醒怎么不亮了)

    华为抬起唤醒怎么设置(华为抬起唤醒怎么不亮了)

  • mn5t2lla是什么版(mn6l2j/a是什么版本)

    mn5t2lla是什么版(mn6l2j/a是什么版本)

  • 荣耀20一键清理在哪里(荣耀20一键清理可以自定义吗)

    荣耀20一键清理在哪里(荣耀20一键清理可以自定义吗)

  • 荣耀10怎么开启红外线(荣耀10怎么开启游戏模式)

    荣耀10怎么开启红外线(荣耀10怎么开启游戏模式)

  • 苹果11抗摔吗(苹果11抗摔嘛)

    苹果11抗摔吗(苹果11抗摔嘛)

  • vivox21是多少英寸的(vivox21是多大屏幕尺寸)

    vivox21是多少英寸的(vivox21是多大屏幕尺寸)

  • 荣耀20和华为nova5pro对比(荣耀20和华为nova5pro哪个好)

    荣耀20和华为nova5pro对比(荣耀20和华为nova5pro哪个好)

  • 抖音怎样设置不让转发(抖音怎样设置不让别人看到我的关注)

    抖音怎样设置不让转发(抖音怎样设置不让别人看到我的关注)

  • 如何关闭oppo新闻资讯(如何关闭oppo新的输入功能)

    如何关闭oppo新闻资讯(如何关闭oppo新的输入功能)

  • iphonexr怎么插耳机(xr在哪里插耳机)

    iphonexr怎么插耳机(xr在哪里插耳机)

  • Win10免费多屏协同不可用怎么办? 投影到此电脑的使用方法(win10 多屏协同)

    Win10免费多屏协同不可用怎么办? 投影到此电脑的使用方法(win10 多屏协同)

  • 冰岛羊 (© John Porter LRPS/Alamy)

    冰岛羊 (© John Porter LRPS/Alamy)

  • 2020年小微企业所得税税率
  • 增值税专票如何作废
  • 收取境外服务费收入如何开票
  • 应交增值税二级科目借贷方向
  • 销项税是什么意思进项税是什么意思
  • 合伙企业的费用在何处扣除
  • 无进项开票税点
  • 没有工会的企业怎么发福利
  • 存货的初始计量应以取得存货的实际成本
  • 固定资产原值包含进项税吗
  • 交了进口增值税还要交增值税吗
  • 福利费进项税额转出会计分录账务处理
  • 个人到税务局开增值税专用发票
  • 个税手续费返还计入哪个科目
  • 新个税法案专项扣除如何实施
  • 没有合同可以收违约金
  • 公司按最低标准缴纳社保
  • 住房按揭贷款贷后检查内容
  • 搅拌站是自用的账务如何做分录?
  • 接收虚开增值税专用发票要判刑吗
  • 建安业一般纳税人企业所得税率是多少
  • 小规模纳税人季度多少免税
  • 附加税减半征收计提和缴纳的会计分录
  • 出口退税企业分类
  • 出租不动产什么时候交税
  • 进项税和销项税的借贷方向
  • 销售收入里面包括免税收入呢
  • 食用盐的增值税是多少
  • 电商平台收取的手续费
  • 股权激励取消怎么处理?
  • 购进货物无偿赠送其他单位
  • 在windows 10中
  • php获取指定日期的时间戳
  • 取得基建借款分录
  • 企业公益性捐赠扣除比例
  • linux动态扩容
  • framework模块
  • 专用发票密码区模糊了影响报销吗
  • 工程施工科目下的招待费,汇算清缴
  • 票据权利期限可以缩短吗
  • 研发专利什么意思
  • 待报解预算收入待结算财政款项
  • python中assert()函数
  • 发票红字冲回账务处理
  • 织梦使用教程
  • 计提基建借款利息会计分录
  • 营业税减免会计分录
  • sqlserver的isnull
  • 投资性房地产租金
  • 总账和明细账有那些
  • 退货对方不开具红字发票怎么办
  • 没有按时对账
  • 选择税的计算方法
  • 房地产开发企业资质管理规定
  • 支出费用的区别
  • 利润总额和未分成比例
  • sqlserver over
  • 此数据库中不存在用户或角色
  • mysql函数大全以及举例
  • ubuntuone
  • vmware安装macos10.15
  • win8 怎么样
  • centos硬盘扩容
  • kochsysteme
  • 移动宽带解绑怎么办理
  • Unity3D游戏开发标准教程吴亚峰于复兴人民邮电出版社
  • html淘宝搜索框代码
  • jQuery+ajax的资源回收处理机制分析
  • ubuntu14重置密码
  • Linux系统安装字体
  • vuex详细教程
  • 跟踪子弹
  • python条件怎么算合法
  • android中常用的布局是
  • js框架开发实例
  • python中的set方法
  • 电子税务局怎么添加开票员
  • 金条如何销售
  • 掌上12333怎么交社保卡费用
  • 营业费用指哪些
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设