位置: IT常识 - 正文

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

编辑:rootadmin
GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点

推荐整理分享GANs系列:CGAN(条件GAN)原理简介以及项目代码实现,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

       生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信息y 融入生成器G和判别器D中,属性y可以是任何标签信息, 例如图像的类别、人脸图像的面部表情等。

二、CGAN的基本原理

      cGAN的中心思想是希望 可以控制 GAN 生成的图片,而不 是单纯的随机生成图片。 具体来说,Conditional GAN 在生成器和判别器的输入中 增加了额外的 条件信息,生成器生成的图片只有足够真实 且与条件相符,才能够通过判别器。

      实际上 , 在无条件约束的生成模型中 , 没法控制数据生成的模式。然而,通过额外的信息对模型进行约束,有可能指导数据生成的过程。条件约束可以是类标签 , 可以是图像修补的部分数据, 甚至是来自不同模态的数据

cGAN将 无监督学习 转为 有监督学习 使得网络可以更好地在我们的掌控下进行学习!

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

从公式看,cgan相当于在原始GAN的基础上对生成器部分 和判别器部分都加了一个条件

三、CGAN模型

如果将上图绿色部分的y去掉,就是GAN的原理图。 

 四、CGAN结构

为了实现条件GAN的目的,生成网络和判别网络的原理和 训练方式均要有所改变。

模型部分,在判别器和生成器中都添加了额外信息 y,y 可 以是类别标签或者是其他类型的数据,可以将 y 作为一个 额外的输入层丢入判别器和生成器。 

在生成器中,作者将输入噪声 z 和 y 连在一起隐含表示, 带条件约束这个简单直接的改进被证明非常有效,并广泛用 于后续的相关工作中。论文是在MNIST数据集上以类别标 签为条件变量,生成指定类别的图像。作者还探索了CGAN 在用于图像自动标注的多模态学习上的应用,在MIR Flickr25000数据集上,以图像特征为条件变量,生成该图像的tag的词向量。

 五、CGAN缺陷

cGAN生成的图像虽有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低等,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像风格时对属性特征的 处理方法均受cGAN启发。

六、代码实现,生成指定手写数字import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimimport numpy as npimport matplotlib.pyplot as pltimport torchvisionfrom torchvision import transformsfrom torch.utils import dataimport osimport globfrom PIL import Image# 独热编码# 输入x代表默认的torchvision返回的类比值,class_count类别值为10def one_hot(x, class_count=10): return torch.eye(class_count)[x, :] # 切片选取,第一维选取第x个,第二维全要transform =transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])dataset = torchvision.datasets.MNIST('data', train=True, transform=transform, target_transform=one_hot, download=False)dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)# 定义生成器class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.linear1 = nn.Linear(10, 128 * 7 * 7) self.bn1 = nn.BatchNorm1d(128 * 7 * 7) self.linear2 = nn.Linear(100, 128 * 7 * 7) self.bn2 = nn.BatchNorm1d(128 * 7 * 7) self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), padding=1) self.bn3 = nn.BatchNorm2d(128) self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=2, padding=1) self.bn4 = nn.BatchNorm2d(64) self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=2, padding=1) def forward(self, x1, x2): x1 = F.relu(self.linear1(x1)) x1 = self.bn1(x1) x1 = x1.view(-1, 128, 7, 7) x2 = F.relu(self.linear2(x2)) x2 = self.bn2(x2) x2 = x2.view(-1, 128, 7, 7) x = torch.cat([x1, x2], axis=1) x = F.relu(self.deconv1(x)) x = self.bn3(x) x = F.relu(self.deconv2(x)) x = self.bn4(x) x = torch.tanh(self.deconv3(x)) return x# 定义判别器# input:1,28,28的图片以及长度为10的conditionclass Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.linear = nn.Linear(10, 1*28*28) self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2) self.bn = nn.BatchNorm2d(128) self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值 def forward(self, x1, x2): x1 =F.leaky_relu(self.linear(x1)) x1 = x1.view(-1, 1, 28, 28) x = torch.cat([x1, x2], axis=1) x = F.dropout2d(F.leaky_relu(self.conv1(x))) x = F.dropout2d(F.leaky_relu(self.conv2(x))) x = self.bn(x) x = x.view(-1, 128*6*6) x = torch.sigmoid(self.fc(x)) return x# 初始化模型device = 'cuda' if torch.cuda.is_available() else 'cpu'gen = Generator().to(device)dis = Discriminator().to(device)# 损失计算函数loss_function = torch.nn.BCELoss()# 定义优化器d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)# 定义可视化函数def generate_and_save_images(model, epoch, label_input, noise_input): predictions = np.squeeze(model(label_input, noise_input).cpu().numpy()) fig = plt.figure(figsize=(4, 4)) for i in range(predictions.shape[0]): plt.subplot(4, 4, i + 1) plt.imshow((predictions[i] + 1) / 2, cmap='gray') plt.axis("off") plt.savefig('D:/practice/CGAN/img/image_at_epoch_{:04d}.png'.format(epoch)) plt.show()noise_seed = torch.randn(16, 100, device=device)label_seed = torch.randint(0, 10, size=(16,))label_seed_onehot = one_hot(label_seed).to(device)print(label_seed)# print(label_seed_onehot)# 开始训练D_loss = []G_loss = []# 训练循环for epoch in range(150): d_epoch_loss = 0 g_epoch_loss = 0 count = len(dataloader.dataset) # 对全部的数据集做一次迭代 for step, (img, label) in enumerate(dataloader): img = img.to(device) label = label.to(device) size = img.shape[0] random_noise = torch.randn(size, 100, device=device) d_optim.zero_grad() real_output = dis(label, img) d_real_loss = loss_function(real_output, torch.ones_like(real_output, device=device) ) d_real_loss.backward() #求解梯度 # 得到判别器在生成图像上的损失 gen_img = gen(label,random_noise) fake_output = dis(label, gen_img.detach()) # 判别器输入生成的图片,f_o是对生成图片的预测结果 d_fake_loss = loss_function(fake_output, torch.zeros_like(fake_output, device=device)) d_fake_loss.backward() d_loss = d_real_loss + d_fake_loss d_optim.step() # 优化 # 得到生成器的损失 g_optim.zero_grad() fake_output = dis(label, gen_img) g_loss = loss_function(fake_output, torch.ones_like(fake_output, device=device)) g_loss.backward() g_optim.step() with torch.no_grad(): d_epoch_loss += d_loss.item() g_epoch_loss += g_loss.item() with torch.no_grad(): d_epoch_loss /= count g_epoch_loss /= count D_loss.append(d_epoch_loss) G_loss.append(g_epoch_loss) if epoch % 10 == 0: print('Epoch:', epoch) generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)plt.plot(D_loss, label='D_loss')plt.plot(G_loss, label='G_loss')plt.legend()plt.show()

具体实战代码解读,参考:GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字

本文链接地址:https://www.jiuchutong.com/zhishi/300285.html 转载请保留说明!

上一篇:R-CNN史上最全讲解(rcnn系列详解)

下一篇:结构重参数化(Structural Re-Parameters)PipLine(结构重参数化2d pose)

  • 增加微博粉丝的二十个方法途径(增加微博粉丝的方法)

    增加微博粉丝的二十个方法途径(增加微博粉丝的方法)

  • iqoo8怎么开启炫酷灯效(iqoo3炫彩灯在哪)

    iqoo8怎么开启炫酷灯效(iqoo3炫彩灯在哪)

  • 苹果xr怎么添加桌面小插件(苹果xr怎么添加小组件)

    苹果xr怎么添加桌面小插件(苹果xr怎么添加小组件)

  • 钉钉有新消息但是没有提示怎么办(钉钉有新消息但图标不显示)

    钉钉有新消息但是没有提示怎么办(钉钉有新消息但图标不显示)

  • 拼多多怎么做任务领电力(拼多多怎么做任务)

    拼多多怎么做任务领电力(拼多多怎么做任务)

  • 抖音gif怎么保存到手机(抖音gif怎么保存到手机本地相册)

    抖音gif怎么保存到手机(抖音gif怎么保存到手机本地相册)

  • 手机静电导致屏幕失灵(手机静电导致屏幕闪烁)

    手机静电导致屏幕失灵(手机静电导致屏幕闪烁)

  • 5v3a是快充吗(5v3a是快充吗?)

    5v3a是快充吗(5v3a是快充吗?)

  • 6000mah等于多少毫安(6000mah等于多少wh)

    6000mah等于多少毫安(6000mah等于多少wh)

  • 电话手表换卡以后为什么不能用(电话手表换卡以后怎么用)

    电话手表换卡以后为什么不能用(电话手表换卡以后怎么用)

  • 如何把全民k歌里的歌保存到自己的手机里(如何把全民k歌转换成mp3格式)

    如何把全民k歌里的歌保存到自己的手机里(如何把全民k歌转换成mp3格式)

  • 电脑查找替换怎么操作(电脑查找和替换用什么键)

    电脑查找替换怎么操作(电脑查找和替换用什么键)

  • ipad屏幕旋转失灵(ipad屏幕旋转方向不对)

    ipad屏幕旋转失灵(ipad屏幕旋转方向不对)

  • 苹果手机充电频繁闪断(苹果手机充电频率高好吗)

    苹果手机充电频繁闪断(苹果手机充电频率高好吗)

  • 笔记本屏幕亮但是不显示桌面(笔记本屏幕亮但是什么都没有为什么)

    笔记本屏幕亮但是不显示桌面(笔记本屏幕亮但是什么都没有为什么)

  • 华为屏幕指纹颜色怎么改(华为屏幕指纹解锁颜色怎么换)

    华为屏幕指纹颜色怎么改(华为屏幕指纹解锁颜色怎么换)

  • 图形图像属于什么媒体(图形图像属于什么专业类别)

    图形图像属于什么媒体(图形图像属于什么专业类别)

  • 电话响铃多久自动挂断(电话响铃多久自动挂断是什么意思)

    电话响铃多久自动挂断(电话响铃多久自动挂断是什么意思)

  • vivo动态锁屏自定义(vivox27动态锁屏)

    vivo动态锁屏自定义(vivox27动态锁屏)

  • 键盘长度一般多少cm(键盘长度一般多少寸)

    键盘长度一般多少cm(键盘长度一般多少寸)

  • 怎么把b站视频导入相册(怎么把b站视频的音频提取出来)

    怎么把b站视频导入相册(怎么把b站视频的音频提取出来)

  • 显示屏音频接口有啥用(显示屏音频接口在哪)

    显示屏音频接口有啥用(显示屏音频接口在哪)

  • 苹果手机11系列什么时候上市

    苹果手机11系列什么时候上市

  • 华为畅享9plus有红外线遥控吗(华为畅享9plus有nfc功能吗)

    华为畅享9plus有红外线遥控吗(华为畅享9plus有nfc功能吗)

  • 苹果手机车载怎么关闭(苹果手机车载怎么连接)

    苹果手机车载怎么关闭(苹果手机车载怎么连接)

  • 阿里卖家故意不发货(阿里卖家不发货怎么办)

    阿里卖家故意不发货(阿里卖家不发货怎么办)

  • win10任务栏隐藏设置教程(win10任务栏隐藏正在运行的程序)

    win10任务栏隐藏设置教程(win10任务栏隐藏正在运行的程序)

  • 退税收入如何做账
  • 回退税款所属期后怎么返回
  • 金税四期不会对个人产生影响
  • 辞退补偿金按照什么工资算
  • 建筑工程发票来自哪里
  • 建筑业企业生产经营情况表
  • 现金流量怎么影响股票价值
  • 增值税16点税降到13点,补缴税款怎么算
  • 其他应收款账户期初借方余额为35400
  • 医院销售药品是干嘛的
  • 公司拍摄产品的文案
  • 应付汇差是什么意思
  • 收储土地资金会计核算办法
  • 基本生产车间领用周转材料会计分录
  • 支付给银行的借款利息属于什么会计要素
  • 一般代开增值税多少个点?
  • 所得减免优惠明细表减免项目包括几项
  • 去年没有交社保,今年交了有用吗
  • 金税三期能查几年前的发票
  • 公司免费使用我的肖像权用作商业
  • 残疾人就业保障金
  • 工伤后辞职了还可以报工伤
  • 预收款没有发票怎么入账
  • 损益表现金流量表资产负债表
  • 12306打不开怎么回事苹果手机
  • 增值税怎么填表
  • 什么是CMOS什么是BIOS
  • system占用cpu过高怎么解决
  • gazebo中机器人导航在rviz中不显示地图仅限显示轨迹
  • 企业收到款项
  • 房地产企业哪些成本上升了
  • 固定资产计提折旧的原则
  • PHP:oci_field_type()的用法_Oracle函数
  • CUDA(10.2)+PyTorch安装加配置 详细完整教程
  • 开办费对应的现金流量项目
  • 败诉方承担诉讼费缴直接付给法院还是胜诉方
  • 实用的开源软件
  • Chatgpt私有化部署(全流程)
  • 以前年度亏损现在不亏了
  • 卖出周转材料的分录怎么做
  • 出口会计分录该怎么写
  • 企税申报表怎么填
  • 增值税专用发票几个点
  • mysql中的外键的作用
  • 房产税的计算器
  • sqlserver2008密码要求
  • 接受捐赠计入
  • 零售行业的销售额由什么决定
  • 个税申报月份错误怎么更改
  • 无形资产的成本包括增值税吗
  • 车间装修预算表
  • 无形资产减值准备借贷方向增减
  • 有限合伙企业如何报税
  • 如何理解递延所得税资产和负债
  • 员工预支钱要写什么单据
  • 技术服务费可以开专票吗
  • 产品质量问题有赔偿吗
  • 冲销凭证如何做分录
  • 小规模收的专票以后能抵扣吗
  • 购入固定资产一次性税前扣除
  • sql分组having
  • windows中任务栏
  • WINDOWS操作系统最新版本
  • uphclean.exe - uphclean进程是什么意思
  • xp系统怎么打开开机启动项
  • eudcedit.exe
  • win102021年1月大更新
  • win8.1系统升级
  • 工商网银登陆
  • 微软windows8.1
  • linux html编辑器
  • win8一直配置更新
  • dos怎么上网
  • Node.js中的construct
  • Javascript new Date().valueOf()的作用与时间戳由来详解
  • node.js可以跨平台吗
  • Qualcom QMI系列-基本知识介绍
  • 用vue做项目加入购物车是怎么做到的
  • jquery中追加到指定元素末尾
  • 小规模纳税人利润如何缴税
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设