位置: IT常识 - 正文

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

编辑:rootadmin
GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点

推荐整理分享GANs系列:CGAN(条件GAN)原理简介以及项目代码实现,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

       生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信息y 融入生成器G和判别器D中,属性y可以是任何标签信息, 例如图像的类别、人脸图像的面部表情等。

二、CGAN的基本原理

      cGAN的中心思想是希望 可以控制 GAN 生成的图片,而不 是单纯的随机生成图片。 具体来说,Conditional GAN 在生成器和判别器的输入中 增加了额外的 条件信息,生成器生成的图片只有足够真实 且与条件相符,才能够通过判别器。

      实际上 , 在无条件约束的生成模型中 , 没法控制数据生成的模式。然而,通过额外的信息对模型进行约束,有可能指导数据生成的过程。条件约束可以是类标签 , 可以是图像修补的部分数据, 甚至是来自不同模态的数据

cGAN将 无监督学习 转为 有监督学习 使得网络可以更好地在我们的掌控下进行学习!

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

从公式看,cgan相当于在原始GAN的基础上对生成器部分 和判别器部分都加了一个条件

三、CGAN模型

如果将上图绿色部分的y去掉,就是GAN的原理图。 

 四、CGAN结构

为了实现条件GAN的目的,生成网络和判别网络的原理和 训练方式均要有所改变。

模型部分,在判别器和生成器中都添加了额外信息 y,y 可 以是类别标签或者是其他类型的数据,可以将 y 作为一个 额外的输入层丢入判别器和生成器。 

在生成器中,作者将输入噪声 z 和 y 连在一起隐含表示, 带条件约束这个简单直接的改进被证明非常有效,并广泛用 于后续的相关工作中。论文是在MNIST数据集上以类别标 签为条件变量,生成指定类别的图像。作者还探索了CGAN 在用于图像自动标注的多模态学习上的应用,在MIR Flickr25000数据集上,以图像特征为条件变量,生成该图像的tag的词向量。

 五、CGAN缺陷

cGAN生成的图像虽有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低等,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像风格时对属性特征的 处理方法均受cGAN启发。

六、代码实现,生成指定手写数字import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimimport numpy as npimport matplotlib.pyplot as pltimport torchvisionfrom torchvision import transformsfrom torch.utils import dataimport osimport globfrom PIL import Image# 独热编码# 输入x代表默认的torchvision返回的类比值,class_count类别值为10def one_hot(x, class_count=10): return torch.eye(class_count)[x, :] # 切片选取,第一维选取第x个,第二维全要transform =transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])dataset = torchvision.datasets.MNIST('data', train=True, transform=transform, target_transform=one_hot, download=False)dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)# 定义生成器class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.linear1 = nn.Linear(10, 128 * 7 * 7) self.bn1 = nn.BatchNorm1d(128 * 7 * 7) self.linear2 = nn.Linear(100, 128 * 7 * 7) self.bn2 = nn.BatchNorm1d(128 * 7 * 7) self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), padding=1) self.bn3 = nn.BatchNorm2d(128) self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=2, padding=1) self.bn4 = nn.BatchNorm2d(64) self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=2, padding=1) def forward(self, x1, x2): x1 = F.relu(self.linear1(x1)) x1 = self.bn1(x1) x1 = x1.view(-1, 128, 7, 7) x2 = F.relu(self.linear2(x2)) x2 = self.bn2(x2) x2 = x2.view(-1, 128, 7, 7) x = torch.cat([x1, x2], axis=1) x = F.relu(self.deconv1(x)) x = self.bn3(x) x = F.relu(self.deconv2(x)) x = self.bn4(x) x = torch.tanh(self.deconv3(x)) return x# 定义判别器# input:1,28,28的图片以及长度为10的conditionclass Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.linear = nn.Linear(10, 1*28*28) self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2) self.bn = nn.BatchNorm2d(128) self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值 def forward(self, x1, x2): x1 =F.leaky_relu(self.linear(x1)) x1 = x1.view(-1, 1, 28, 28) x = torch.cat([x1, x2], axis=1) x = F.dropout2d(F.leaky_relu(self.conv1(x))) x = F.dropout2d(F.leaky_relu(self.conv2(x))) x = self.bn(x) x = x.view(-1, 128*6*6) x = torch.sigmoid(self.fc(x)) return x# 初始化模型device = 'cuda' if torch.cuda.is_available() else 'cpu'gen = Generator().to(device)dis = Discriminator().to(device)# 损失计算函数loss_function = torch.nn.BCELoss()# 定义优化器d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)# 定义可视化函数def generate_and_save_images(model, epoch, label_input, noise_input): predictions = np.squeeze(model(label_input, noise_input).cpu().numpy()) fig = plt.figure(figsize=(4, 4)) for i in range(predictions.shape[0]): plt.subplot(4, 4, i + 1) plt.imshow((predictions[i] + 1) / 2, cmap='gray') plt.axis("off") plt.savefig('D:/practice/CGAN/img/image_at_epoch_{:04d}.png'.format(epoch)) plt.show()noise_seed = torch.randn(16, 100, device=device)label_seed = torch.randint(0, 10, size=(16,))label_seed_onehot = one_hot(label_seed).to(device)print(label_seed)# print(label_seed_onehot)# 开始训练D_loss = []G_loss = []# 训练循环for epoch in range(150): d_epoch_loss = 0 g_epoch_loss = 0 count = len(dataloader.dataset) # 对全部的数据集做一次迭代 for step, (img, label) in enumerate(dataloader): img = img.to(device) label = label.to(device) size = img.shape[0] random_noise = torch.randn(size, 100, device=device) d_optim.zero_grad() real_output = dis(label, img) d_real_loss = loss_function(real_output, torch.ones_like(real_output, device=device) ) d_real_loss.backward() #求解梯度 # 得到判别器在生成图像上的损失 gen_img = gen(label,random_noise) fake_output = dis(label, gen_img.detach()) # 判别器输入生成的图片,f_o是对生成图片的预测结果 d_fake_loss = loss_function(fake_output, torch.zeros_like(fake_output, device=device)) d_fake_loss.backward() d_loss = d_real_loss + d_fake_loss d_optim.step() # 优化 # 得到生成器的损失 g_optim.zero_grad() fake_output = dis(label, gen_img) g_loss = loss_function(fake_output, torch.ones_like(fake_output, device=device)) g_loss.backward() g_optim.step() with torch.no_grad(): d_epoch_loss += d_loss.item() g_epoch_loss += g_loss.item() with torch.no_grad(): d_epoch_loss /= count g_epoch_loss /= count D_loss.append(d_epoch_loss) G_loss.append(g_epoch_loss) if epoch % 10 == 0: print('Epoch:', epoch) generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)plt.plot(D_loss, label='D_loss')plt.plot(G_loss, label='G_loss')plt.legend()plt.show()

具体实战代码解读,参考:GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字

本文链接地址:https://www.jiuchutong.com/zhishi/300285.html 转载请保留说明!

上一篇:R-CNN史上最全讲解(rcnn系列详解)

下一篇:结构重参数化(Structural Re-Parameters)PipLine(结构重参数化2d pose)

  • word文档中冲蚀在哪(word文档冲蚀)

    word文档中冲蚀在哪(word文档冲蚀)

  • 闲鱼发票能给别人吗(闲鱼卖东西可以给发票吗)

    闲鱼发票能给别人吗(闲鱼卖东西可以给发票吗)

  • kindle更新慢(kindle更新速度)

    kindle更新慢(kindle更新速度)

  • 下列选项不属于word窗口组成部分的是(下列选项不属于重度哮喘的附加治疗药物的是)

    下列选项不属于word窗口组成部分的是(下列选项不属于重度哮喘的附加治疗药物的是)

  • 苹果笔记本13.3寸多大(苹果笔记本13.3和15.4对比)

    苹果笔记本13.3寸多大(苹果笔记本13.3和15.4对比)

  • ois是啥

    ois是啥

  • 电脑硬盘突然消失不见(电脑硬盘突然消失了怎么回事)

    电脑硬盘突然消失不见(电脑硬盘突然消失了怎么回事)

  • 控制和管理计算机硬件和软件资源的是什么(控制和管理计算机硬件和软件资源,合理地组织)

    控制和管理计算机硬件和软件资源的是什么(控制和管理计算机硬件和软件资源,合理地组织)

  • 苹果x面容进水怎么修复(苹果x面容进水无法用等他水干了能用吗)

    苹果x面容进水怎么修复(苹果x面容进水无法用等他水干了能用吗)

  • 苹果7p微信消息不会弹窗提醒(苹果7p微信消息延迟)

    苹果7p微信消息不会弹窗提醒(苹果7p微信消息延迟)

  • oppo锁屏有广告怎么去掉(oppo锁屏广告怎么彻底关掉)

    oppo锁屏有广告怎么去掉(oppo锁屏广告怎么彻底关掉)

  • 如果微信连着麦闹钟会响吗(如果微信连着麦闹钟响了关了语音就挂了)

    如果微信连着麦闹钟会响吗(如果微信连着麦闹钟响了关了语音就挂了)

  • 抖音私聊已读显示(抖音私聊对方已读)

    抖音私聊已读显示(抖音私聊对方已读)

  • 苹果笔记本128g够装双系统吗(苹果笔记本128g多少钱)

    苹果笔记本128g够装双系统吗(苹果笔记本128g多少钱)

  • 钉钉群直播和视频会议有什么区别(钉钉群直播和视频会议)

    钉钉群直播和视频会议有什么区别(钉钉群直播和视频会议)

  • 怎么改图片的大小kb(怎么改图片的大小M)

    怎么改图片的大小kb(怎么改图片的大小M)

  • vivonex3屏占比多少(vivonex3屏幕多大尺寸)

    vivonex3屏占比多少(vivonex3屏幕多大尺寸)

  • mt6762相当于骁龙哪种(mtkmt6762相当于骁龙多少)

    mt6762相当于骁龙哪种(mtkmt6762相当于骁龙多少)

  • 爱奇艺会员可以几个人用(爱奇艺会员可以几个人一起登录)

    爱奇艺会员可以几个人用(爱奇艺会员可以几个人一起登录)

  • 微信支付立减金是什么(微信支付立减金怎么获得)

    微信支付立减金是什么(微信支付立减金怎么获得)

  • word文档如何竖排文字(word文档如何竖着写字)

    word文档如何竖排文字(word文档如何竖着写字)

  • 华为最新款手机是哪款2022详情(华为最新款手机2023款)

    华为最新款手机是哪款2022详情(华为最新款手机2023款)

  • Dedecms Ask问答系统Rewrite规则(官方的问答)

    Dedecms Ask问答系统Rewrite规则(官方的问答)

  • 注销空白缴销发票流程
  • 土地增值税清算管理规程
  • 建筑公司包工包料账务处理
  • 微信企业版支付
  • 固定资产的入账
  • 建筑施工企业增值税税率是多少
  • 个税手续费发给个人怎么做账
  • 利息收入记借方负数表示增加还是减少
  • 建筑行业预缴增值税可以用进项抵缴吗
  • 费用已付发票未到的预算会计分录
  • 滞纳金按什么比例算
  • 资本公积 转增
  • 待评估资产价值
  • 法院退诉讼费账务处理
  • 制造行业运输费包括哪些
  • 抄税和上报汇总一样吗
  • 异地预缴企业所得税几个点
  • 异地施工需要缴纳什么税
  • 集体土地上的不动产证已经能查询为何房产证拿不到
  • 服务费专票普票
  • 年数总和折旧计算方法
  • 差额纳税计算方法
  • 房企行业其他应收账款的来源是什么
  • 季度所得税预缴税款表中主营业务成本是否包含管理费用
  • 购销合同印花税计税依据
  • 总杠杆系数的计算公司
  • 固定资产折旧会计做账
  • 个税申报密码是什么意思
  • 税务代开的专票未取票,逾期会作废吗?
  • 财务报表未分配利润为负数
  • 购买产品优惠计入什么科目
  • ipad哪款最贵
  • 存放中央银行款项科目按其资金性质
  • 七个超级实用的手机
  • 发票抵扣联能报销吗
  • ConvNeXt V2学习笔记
  • 使用二氧化碳灭火器时人应该站在什么位置
  • chat怎么用
  • unplugin-auto-import github
  • rmt命令 远端磁带传输协议模块
  • php二维数组foreach
  • sql server应用
  • 确定负债排列顺序的依据
  • 自然人独资和个人独资是一样的吗
  • 支付境外货款需要缴纳哪些税费
  • 工程预付款如何缴税
  • 待抵扣进项税计入其他应付账款吗
  • 契税是房价乘以1.5吗
  • 润滑油消费税征收环节税屋
  • 公司收到现金货款怎么存银行
  • 退货收到红字发票怎么办
  • 应交税费为负数在资产负债表中的列报
  • 小微企业即征即退
  • 勾选发票必须当月认证吗
  • 投资收益下期间怎么结转
  • sql server怎么复制表
  • linux信号机制的原理
  • centos安装软件教程
  • 生产环境如何对linux进行合理分区
  • 优盘和硬盘
  • 进程crash是什么意思
  • 文件选项夹在哪里
  • macbook备忘录字数统计
  • win8.1启动项设置
  • windows8.1激活方法
  • win10升级安装视频
  • android break
  • js基于贪心算法实验报告
  • 获取nodejs命令行信息
  • 如何设置div自适应宽度
  • socket怎么用
  • Nodejs+express+ejs简单使用实例代码
  • 使用vs code开发Django
  • 技术总结2000字
  • Android SDK Manager无法更新的解决方案
  • 五四新文化运动究竟新在哪里
  • 单位名称变更后发票还能用吗
  • 营销代码是多少
  • 2021年十大慈善企业
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设