位置: IT常识 - 正文

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

编辑:rootadmin
GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点

推荐整理分享GANs系列:CGAN(条件GAN)原理简介以及项目代码实现,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

       生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信息y 融入生成器G和判别器D中,属性y可以是任何标签信息, 例如图像的类别、人脸图像的面部表情等。

二、CGAN的基本原理

      cGAN的中心思想是希望 可以控制 GAN 生成的图片,而不 是单纯的随机生成图片。 具体来说,Conditional GAN 在生成器和判别器的输入中 增加了额外的 条件信息,生成器生成的图片只有足够真实 且与条件相符,才能够通过判别器。

      实际上 , 在无条件约束的生成模型中 , 没法控制数据生成的模式。然而,通过额外的信息对模型进行约束,有可能指导数据生成的过程。条件约束可以是类标签 , 可以是图像修补的部分数据, 甚至是来自不同模态的数据

cGAN将 无监督学习 转为 有监督学习 使得网络可以更好地在我们的掌控下进行学习!

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

从公式看,cgan相当于在原始GAN的基础上对生成器部分 和判别器部分都加了一个条件

三、CGAN模型

如果将上图绿色部分的y去掉,就是GAN的原理图。 

 四、CGAN结构

为了实现条件GAN的目的,生成网络和判别网络的原理和 训练方式均要有所改变。

模型部分,在判别器和生成器中都添加了额外信息 y,y 可 以是类别标签或者是其他类型的数据,可以将 y 作为一个 额外的输入层丢入判别器和生成器。 

在生成器中,作者将输入噪声 z 和 y 连在一起隐含表示, 带条件约束这个简单直接的改进被证明非常有效,并广泛用 于后续的相关工作中。论文是在MNIST数据集上以类别标 签为条件变量,生成指定类别的图像。作者还探索了CGAN 在用于图像自动标注的多模态学习上的应用,在MIR Flickr25000数据集上,以图像特征为条件变量,生成该图像的tag的词向量。

 五、CGAN缺陷

cGAN生成的图像虽有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低等,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像风格时对属性特征的 处理方法均受cGAN启发。

六、代码实现,生成指定手写数字import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimimport numpy as npimport matplotlib.pyplot as pltimport torchvisionfrom torchvision import transformsfrom torch.utils import dataimport osimport globfrom PIL import Image# 独热编码# 输入x代表默认的torchvision返回的类比值,class_count类别值为10def one_hot(x, class_count=10): return torch.eye(class_count)[x, :] # 切片选取,第一维选取第x个,第二维全要transform =transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])dataset = torchvision.datasets.MNIST('data', train=True, transform=transform, target_transform=one_hot, download=False)dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)# 定义生成器class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.linear1 = nn.Linear(10, 128 * 7 * 7) self.bn1 = nn.BatchNorm1d(128 * 7 * 7) self.linear2 = nn.Linear(100, 128 * 7 * 7) self.bn2 = nn.BatchNorm1d(128 * 7 * 7) self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), padding=1) self.bn3 = nn.BatchNorm2d(128) self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=2, padding=1) self.bn4 = nn.BatchNorm2d(64) self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=2, padding=1) def forward(self, x1, x2): x1 = F.relu(self.linear1(x1)) x1 = self.bn1(x1) x1 = x1.view(-1, 128, 7, 7) x2 = F.relu(self.linear2(x2)) x2 = self.bn2(x2) x2 = x2.view(-1, 128, 7, 7) x = torch.cat([x1, x2], axis=1) x = F.relu(self.deconv1(x)) x = self.bn3(x) x = F.relu(self.deconv2(x)) x = self.bn4(x) x = torch.tanh(self.deconv3(x)) return x# 定义判别器# input:1,28,28的图片以及长度为10的conditionclass Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.linear = nn.Linear(10, 1*28*28) self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2) self.bn = nn.BatchNorm2d(128) self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值 def forward(self, x1, x2): x1 =F.leaky_relu(self.linear(x1)) x1 = x1.view(-1, 1, 28, 28) x = torch.cat([x1, x2], axis=1) x = F.dropout2d(F.leaky_relu(self.conv1(x))) x = F.dropout2d(F.leaky_relu(self.conv2(x))) x = self.bn(x) x = x.view(-1, 128*6*6) x = torch.sigmoid(self.fc(x)) return x# 初始化模型device = 'cuda' if torch.cuda.is_available() else 'cpu'gen = Generator().to(device)dis = Discriminator().to(device)# 损失计算函数loss_function = torch.nn.BCELoss()# 定义优化器d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)# 定义可视化函数def generate_and_save_images(model, epoch, label_input, noise_input): predictions = np.squeeze(model(label_input, noise_input).cpu().numpy()) fig = plt.figure(figsize=(4, 4)) for i in range(predictions.shape[0]): plt.subplot(4, 4, i + 1) plt.imshow((predictions[i] + 1) / 2, cmap='gray') plt.axis("off") plt.savefig('D:/practice/CGAN/img/image_at_epoch_{:04d}.png'.format(epoch)) plt.show()noise_seed = torch.randn(16, 100, device=device)label_seed = torch.randint(0, 10, size=(16,))label_seed_onehot = one_hot(label_seed).to(device)print(label_seed)# print(label_seed_onehot)# 开始训练D_loss = []G_loss = []# 训练循环for epoch in range(150): d_epoch_loss = 0 g_epoch_loss = 0 count = len(dataloader.dataset) # 对全部的数据集做一次迭代 for step, (img, label) in enumerate(dataloader): img = img.to(device) label = label.to(device) size = img.shape[0] random_noise = torch.randn(size, 100, device=device) d_optim.zero_grad() real_output = dis(label, img) d_real_loss = loss_function(real_output, torch.ones_like(real_output, device=device) ) d_real_loss.backward() #求解梯度 # 得到判别器在生成图像上的损失 gen_img = gen(label,random_noise) fake_output = dis(label, gen_img.detach()) # 判别器输入生成的图片,f_o是对生成图片的预测结果 d_fake_loss = loss_function(fake_output, torch.zeros_like(fake_output, device=device)) d_fake_loss.backward() d_loss = d_real_loss + d_fake_loss d_optim.step() # 优化 # 得到生成器的损失 g_optim.zero_grad() fake_output = dis(label, gen_img) g_loss = loss_function(fake_output, torch.ones_like(fake_output, device=device)) g_loss.backward() g_optim.step() with torch.no_grad(): d_epoch_loss += d_loss.item() g_epoch_loss += g_loss.item() with torch.no_grad(): d_epoch_loss /= count g_epoch_loss /= count D_loss.append(d_epoch_loss) G_loss.append(g_epoch_loss) if epoch % 10 == 0: print('Epoch:', epoch) generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)plt.plot(D_loss, label='D_loss')plt.plot(G_loss, label='G_loss')plt.legend()plt.show()

具体实战代码解读,参考:GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字

本文链接地址:https://www.jiuchutong.com/zhishi/300285.html 转载请保留说明!

上一篇:R-CNN史上最全讲解(rcnn系列详解)

下一篇:结构重参数化(Structural Re-Parameters)PipLine(结构重参数化2d pose)

  • 苹果手机微信跟随系统深色模式如何设置(苹果手机微信跟别人视频,对方听不到声音)

    苹果手机微信跟随系统深色模式如何设置(苹果手机微信跟别人视频,对方听不到声音)

  • 华为荣耀20pro有呼吸灯吗(华为荣耀20pro有红外线遥控功能吗)

    华为荣耀20pro有呼吸灯吗(华为荣耀20pro有红外线遥控功能吗)

  • 暴风影音中文字幕乱码(暴风影音中文字幕下载)

    暴风影音中文字幕乱码(暴风影音中文字幕下载)

  • 三星手机系统更新对手机有没有影响(三星手机系统更新好吗)

    三星手机系统更新对手机有没有影响(三星手机系统更新好吗)

  • mix3什么时候升级miui12(mix三什么时候更新)

    mix3什么时候升级miui12(mix三什么时候更新)

  • 电信宽带已连接但不能上网(电信宽带已连接但不能上网怎么解决)

    电信宽带已连接但不能上网(电信宽带已连接但不能上网怎么解决)

  • 爱奇艺vip能几个人用(爱奇艺VIP能几个设备登录)

    爱奇艺vip能几个人用(爱奇艺VIP能几个设备登录)

  • 淘宝卡包在哪里(淘宝卡包劵在哪里)

    淘宝卡包在哪里(淘宝卡包劵在哪里)

  • mr7f2ch/a是什么型号(mr7k2ch/a是什么型号)

    mr7f2ch/a是什么型号(mr7k2ch/a是什么型号)

  • 为什么微信更新了没有深色模式(为什么微信更新包准备失败)

    为什么微信更新了没有深色模式(为什么微信更新包准备失败)

  • 哪款ipad能插电话卡(ipad支持充电)

    哪款ipad能插电话卡(ipad支持充电)

  • 小米cc9e的充电器是快充吗(小米cc9e充电器参数多少w)

    小米cc9e的充电器是快充吗(小米cc9e充电器参数多少w)

  • 西瓜影音视频缓存在哪(西瓜影音视频缓存在哪里)

    西瓜影音视频缓存在哪(西瓜影音视频缓存在哪里)

  • 唯品会钱包在哪里查看(唯品会的唯品钱包在哪?)

    唯品会钱包在哪里查看(唯品会的唯品钱包在哪?)

  • 手机4g速度慢怎么回事(手机4g特别慢)

    手机4g速度慢怎么回事(手机4g特别慢)

  • 手机qq怎么改群头像(手机QQ怎么改群名片)

    手机qq怎么改群头像(手机QQ怎么改群名片)

  • vue怎么合并几个视频(vue项目合并)

    vue怎么合并几个视频(vue项目合并)

  • iphone11怎么拍夜景(如何用iphone11拍夜景)

    iphone11怎么拍夜景(如何用iphone11拍夜景)

  • 探探左划了还能遇到吗(探探被左滑了还可以再滑到对方吗)

    探探左划了还能遇到吗(探探被左滑了还可以再滑到对方吗)

  • 抖音作品赞为什么会消失(抖音赞不多是什么原因)

    抖音作品赞为什么会消失(抖音赞不多是什么原因)

  • 单声道音频需要打开吗(单声道音频需不需要开)

    单声道音频需要打开吗(单声道音频需不需要开)

  • 如何在手机上制作图片视频(如何在手机上制作ppt课件)

    如何在手机上制作图片视频(如何在手机上制作ppt课件)

  • 转发别人的抖音怎么删除(转发别人的抖音视频属于侵犯行为吗)

    转发别人的抖音怎么删除(转发别人的抖音视频属于侵犯行为吗)

  • 随拍只有好友能看吗(随拍只有好友能拍吗)

    随拍只有好友能看吗(随拍只有好友能拍吗)

  • 手机人脸识别可以用照片吗(手机人脸识别可以用视频解锁吗)

    手机人脸识别可以用照片吗(手机人脸识别可以用视频解锁吗)

  • linux怎么在history命令中前面显示日期?(Linux怎么在目录中创建文件)

    linux怎么在history命令中前面显示日期?(Linux怎么在目录中创建文件)

  • 购买车位的税费是多少钱
  • 增值税加计扣除是什么意思啊
  • 什么是进项税额转出
  • 个人所得税红利20%
  • 处置资产开啥发票
  • 发票勾选比账上多
  • 取用备用金要188分
  • 交通运输行业属于什么性质
  • 嵌入式软件产品的批准放行
  • 跨年的费用怎么调整
  • 无偿赠送商品要纳企业所得税吗
  • 库存商品进项税额转出分录怎么写
  • 质量问题扣对方货款账务处理
  • 进项税转出企业所得税账务怎么处理
  • 挂靠被查出来后挂靠费怎么处理?
  • 差错更正要调去年的吗
  • 文化建设事业费逾期申报有罚款吗
  • 增值税发票价税合计是什么意思
  • 洒水车属于免税车辆吗?
  • 小规模纳税人需要做账吗
  • 企业所得税哪些不可以税前扣除
  • 电子钥匙到期怎么办
  • 流氓软件怎么卸载?
  • 不能升级win11的二手电脑值得购买吗
  • 网页游戏玩着卡
  • 无发票 入账
  • 在建工程领用原材料需要进项税转出吗
  • 上市公司的组织形式
  • 销售部门品种多怎么说
  • 取得土地所有权范围内的树如何处理
  • 注册表被恶意锁定怎么恢复正常
  • 开机要按f1才能进系统
  • widows11预览版
  • 建筑业开具发票
  • win10商店发生了错误请稍后再试
  • 城镇土地使用税纳税义务发生时间
  • php框架symfony
  • php图片大小设置
  • 不需要支付的应付款情况说明
  • 企业进项税怎么查询
  • python二叉树遍历算法
  • 购买图书可以开增值税专票么?
  • 外包人员的餐费可以全部扣除吗
  • 购入固定资产的会计科目
  • 小程序渲染是什么意思
  • sqlserver连接不到本地服务器
  • python wordcloud库
  • 原始凭证在账务处理程序中的作用
  • 工业生产的含义
  • 个税申报子女教育有年龄限制吗
  • 为什么开票需要提供开户许可证
  • 固定资产大修理和更新改造的区别
  • 已经认证抵扣的发票会计分录
  • 差旅费属于什么支出类型
  • 以土地使用权投资入股是否缴纳增值税
  • 本年利润在明细里怎么填
  • 大额装修费按几年摊销
  • 担保公司的担保费能退吗
  • 展位费按多少税率
  • 解析包出现错误无法安装怎么办
  • window10 uwp
  • WINDOWS7系统安装包
  • 标签windows
  • win7无法双击打开软件
  • win10一直弹werfault,程序也打不开
  • win7系统怎么关闭防火墙设置
  • win单击变双击
  • win8关闭系统更新
  • 2016年首个熊猫电站是哪一个
  • cocos2dx安装和初步使用
  • 在flash中制作课件一般会遵循什么流程
  • css expression 隔行换色
  • Unity的WWW类的用法整理
  • shell 批量删除
  • unity屏幕坐标 ui坐标
  • 房屋设备租赁费
  • 电子税务局申报密码怎么设置?
  • 个体户定额怎么查询
  • ca登录的用户名和密码分别是什么
  • 企业所得税
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设