位置: IT常识 - 正文

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

编辑:rootadmin
GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点

推荐整理分享GANs系列:CGAN(条件GAN)原理简介以及项目代码实现,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

       生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信息y 融入生成器G和判别器D中,属性y可以是任何标签信息, 例如图像的类别、人脸图像的面部表情等。

二、CGAN的基本原理

      cGAN的中心思想是希望 可以控制 GAN 生成的图片,而不 是单纯的随机生成图片。 具体来说,Conditional GAN 在生成器和判别器的输入中 增加了额外的 条件信息,生成器生成的图片只有足够真实 且与条件相符,才能够通过判别器。

      实际上 , 在无条件约束的生成模型中 , 没法控制数据生成的模式。然而,通过额外的信息对模型进行约束,有可能指导数据生成的过程。条件约束可以是类标签 , 可以是图像修补的部分数据, 甚至是来自不同模态的数据

cGAN将 无监督学习 转为 有监督学习 使得网络可以更好地在我们的掌控下进行学习!

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

从公式看,cgan相当于在原始GAN的基础上对生成器部分 和判别器部分都加了一个条件

三、CGAN模型

如果将上图绿色部分的y去掉,就是GAN的原理图。 

 四、CGAN结构

为了实现条件GAN的目的,生成网络和判别网络的原理和 训练方式均要有所改变。

模型部分,在判别器和生成器中都添加了额外信息 y,y 可 以是类别标签或者是其他类型的数据,可以将 y 作为一个 额外的输入层丢入判别器和生成器。 

在生成器中,作者将输入噪声 z 和 y 连在一起隐含表示, 带条件约束这个简单直接的改进被证明非常有效,并广泛用 于后续的相关工作中。论文是在MNIST数据集上以类别标 签为条件变量,生成指定类别的图像。作者还探索了CGAN 在用于图像自动标注的多模态学习上的应用,在MIR Flickr25000数据集上,以图像特征为条件变量,生成该图像的tag的词向量。

 五、CGAN缺陷

cGAN生成的图像虽有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低等,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像风格时对属性特征的 处理方法均受cGAN启发。

六、代码实现,生成指定手写数字import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimimport numpy as npimport matplotlib.pyplot as pltimport torchvisionfrom torchvision import transformsfrom torch.utils import dataimport osimport globfrom PIL import Image# 独热编码# 输入x代表默认的torchvision返回的类比值,class_count类别值为10def one_hot(x, class_count=10): return torch.eye(class_count)[x, :] # 切片选取,第一维选取第x个,第二维全要transform =transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])dataset = torchvision.datasets.MNIST('data', train=True, transform=transform, target_transform=one_hot, download=False)dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)# 定义生成器class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.linear1 = nn.Linear(10, 128 * 7 * 7) self.bn1 = nn.BatchNorm1d(128 * 7 * 7) self.linear2 = nn.Linear(100, 128 * 7 * 7) self.bn2 = nn.BatchNorm1d(128 * 7 * 7) self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), padding=1) self.bn3 = nn.BatchNorm2d(128) self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=2, padding=1) self.bn4 = nn.BatchNorm2d(64) self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=2, padding=1) def forward(self, x1, x2): x1 = F.relu(self.linear1(x1)) x1 = self.bn1(x1) x1 = x1.view(-1, 128, 7, 7) x2 = F.relu(self.linear2(x2)) x2 = self.bn2(x2) x2 = x2.view(-1, 128, 7, 7) x = torch.cat([x1, x2], axis=1) x = F.relu(self.deconv1(x)) x = self.bn3(x) x = F.relu(self.deconv2(x)) x = self.bn4(x) x = torch.tanh(self.deconv3(x)) return x# 定义判别器# input:1,28,28的图片以及长度为10的conditionclass Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.linear = nn.Linear(10, 1*28*28) self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2) self.bn = nn.BatchNorm2d(128) self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值 def forward(self, x1, x2): x1 =F.leaky_relu(self.linear(x1)) x1 = x1.view(-1, 1, 28, 28) x = torch.cat([x1, x2], axis=1) x = F.dropout2d(F.leaky_relu(self.conv1(x))) x = F.dropout2d(F.leaky_relu(self.conv2(x))) x = self.bn(x) x = x.view(-1, 128*6*6) x = torch.sigmoid(self.fc(x)) return x# 初始化模型device = 'cuda' if torch.cuda.is_available() else 'cpu'gen = Generator().to(device)dis = Discriminator().to(device)# 损失计算函数loss_function = torch.nn.BCELoss()# 定义优化器d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)# 定义可视化函数def generate_and_save_images(model, epoch, label_input, noise_input): predictions = np.squeeze(model(label_input, noise_input).cpu().numpy()) fig = plt.figure(figsize=(4, 4)) for i in range(predictions.shape[0]): plt.subplot(4, 4, i + 1) plt.imshow((predictions[i] + 1) / 2, cmap='gray') plt.axis("off") plt.savefig('D:/practice/CGAN/img/image_at_epoch_{:04d}.png'.format(epoch)) plt.show()noise_seed = torch.randn(16, 100, device=device)label_seed = torch.randint(0, 10, size=(16,))label_seed_onehot = one_hot(label_seed).to(device)print(label_seed)# print(label_seed_onehot)# 开始训练D_loss = []G_loss = []# 训练循环for epoch in range(150): d_epoch_loss = 0 g_epoch_loss = 0 count = len(dataloader.dataset) # 对全部的数据集做一次迭代 for step, (img, label) in enumerate(dataloader): img = img.to(device) label = label.to(device) size = img.shape[0] random_noise = torch.randn(size, 100, device=device) d_optim.zero_grad() real_output = dis(label, img) d_real_loss = loss_function(real_output, torch.ones_like(real_output, device=device) ) d_real_loss.backward() #求解梯度 # 得到判别器在生成图像上的损失 gen_img = gen(label,random_noise) fake_output = dis(label, gen_img.detach()) # 判别器输入生成的图片,f_o是对生成图片的预测结果 d_fake_loss = loss_function(fake_output, torch.zeros_like(fake_output, device=device)) d_fake_loss.backward() d_loss = d_real_loss + d_fake_loss d_optim.step() # 优化 # 得到生成器的损失 g_optim.zero_grad() fake_output = dis(label, gen_img) g_loss = loss_function(fake_output, torch.ones_like(fake_output, device=device)) g_loss.backward() g_optim.step() with torch.no_grad(): d_epoch_loss += d_loss.item() g_epoch_loss += g_loss.item() with torch.no_grad(): d_epoch_loss /= count g_epoch_loss /= count D_loss.append(d_epoch_loss) G_loss.append(g_epoch_loss) if epoch % 10 == 0: print('Epoch:', epoch) generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)plt.plot(D_loss, label='D_loss')plt.plot(G_loss, label='G_loss')plt.legend()plt.show()

具体实战代码解读,参考:GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字

本文链接地址:https://www.jiuchutong.com/zhishi/300285.html 转载请保留说明!

上一篇:R-CNN史上最全讲解(rcnn系列详解)

下一篇:结构重参数化(Structural Re-Parameters)PipLine(结构重参数化2d pose)

  • 税务申报后就可以清盘了吗
  • 丢失增值税专用发票最新规定
  • 印花税纳税义务人有哪些
  • 材料费用发票的记账凭证
  • 收到股东投资款怎么做账
  • 交易性金融资产的账务处理
  • 合同负债里面含增值税吗
  • 银行汇票存款和银行存款的区别
  • 行政单位应缴财政收入预算会计分录
  • 内账应收应付算利润吗
  • 招标代理公司转让
  • 公司阅览室布置图片
  • 企业一般户可以扣税吗
  • 工会经费计税依据是应发工资还是实发工资
  • 商品流通企业会计心得体会3000字
  • 一般纳税人出租不动产增值税税率
  • 亏损的递延所得税怎么理解
  • 小规模季报资产总额填错了有影响吗
  • 其他应付款调整到其他应收款
  • 跨年如何冲减预提费用?
  • 销售收入大于纳税申报销售收入
  • 关闭guest账户
  • u盘启动器安装系统
  • 费用跨年的分录怎么做
  • 计提税金怎么提
  • 荣耀x10的鸿蒙系统怎么开启
  • 企业当期产生的外币报表折算差额
  • Win10 (21H1)Build 19043.1266更新补丁KB5005611正式版发布:附修复更新内容
  • vpengine.exe进程
  • PHP:oci_fetch_array()的用法_Oracle函数
  • php中ajax
  • php异步处理方案
  • 深度学习第一步——Pytorch-Gpu环境配置:Win11/Win10+Cuda10.2+cuDNN8.5.0+Pytorch1.8.0(步步巨细,少走十年弯路)
  • php获取表单数据保存到mysql中
  • php序列化和反序列化函数
  • php定义一个二维数组
  • test指令怎么用
  • 电子发票开出后如何查看
  • 小企业短期借款科目的贷方登记
  • python搞自动化
  • 上季度忘记申报个税了
  • 企业会计本年利润
  • 政府补助的分类包括
  • 会计核算职能有全面性吗
  • 增值税税控系列是什么
  • 原材料属于固定资本还是流动资本
  • 土地承包经营权上的房屋
  • 零申报企业年报资产状况信息怎么填
  • 固定资产清理属于什么科目借方增加还是减少
  • 稳岗返还计入营业外收入
  • 报名费无发票要补交吗
  • 增值税其他免税销售额
  • 仓库做账应该注意些哪些事项
  • 期后事项的分类及处理原则
  • 知识经济对会计的影响论文
  • win8n
  • ubuntu无法正常开机
  • win8.1 升级
  • xp系统无法预览图片
  • 运行ghost
  • win10更新后安装包会自动删除吗
  • 如何将win10系统从c盘迁移到d盘
  • linux查看所有硬件信息命令
  • Win10系统中怎么将文件夹进行压缩
  • win10系统安装cad2008的注册机无法打开
  • cocos2d怎么用
  • nodejs的express框架详解
  • wifi显示开发状态
  • bash脚本语法
  • k mean python
  • python安装pip.whl
  • 用js实现冒泡排序
  • nodejs gyp
  • JavaScript中length属性的使用方法
  • 用shell脚本创建用户
  • python 父类方法
  • 怎么查询河南省考职位报名人数
  • 防伪税控维护费普通发票怎么申报
  • 税务异地协查系统管理办法
  • 财税公司经营范围介绍
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设