位置: IT常识 - 正文

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

编辑:rootadmin
GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点

推荐整理分享GANs系列:CGAN(条件GAN)原理简介以及项目代码实现,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

       生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信息y 融入生成器G和判别器D中,属性y可以是任何标签信息, 例如图像的类别、人脸图像的面部表情等。

二、CGAN的基本原理

      cGAN的中心思想是希望 可以控制 GAN 生成的图片,而不 是单纯的随机生成图片。 具体来说,Conditional GAN 在生成器和判别器的输入中 增加了额外的 条件信息,生成器生成的图片只有足够真实 且与条件相符,才能够通过判别器。

      实际上 , 在无条件约束的生成模型中 , 没法控制数据生成的模式。然而,通过额外的信息对模型进行约束,有可能指导数据生成的过程。条件约束可以是类标签 , 可以是图像修补的部分数据, 甚至是来自不同模态的数据

cGAN将 无监督学习 转为 有监督学习 使得网络可以更好地在我们的掌控下进行学习!

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

从公式看,cgan相当于在原始GAN的基础上对生成器部分 和判别器部分都加了一个条件

三、CGAN模型

如果将上图绿色部分的y去掉,就是GAN的原理图。 

 四、CGAN结构

为了实现条件GAN的目的,生成网络和判别网络的原理和 训练方式均要有所改变。

模型部分,在判别器和生成器中都添加了额外信息 y,y 可 以是类别标签或者是其他类型的数据,可以将 y 作为一个 额外的输入层丢入判别器和生成器。 

在生成器中,作者将输入噪声 z 和 y 连在一起隐含表示, 带条件约束这个简单直接的改进被证明非常有效,并广泛用 于后续的相关工作中。论文是在MNIST数据集上以类别标 签为条件变量,生成指定类别的图像。作者还探索了CGAN 在用于图像自动标注的多模态学习上的应用,在MIR Flickr25000数据集上,以图像特征为条件变量,生成该图像的tag的词向量。

 五、CGAN缺陷

cGAN生成的图像虽有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低等,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像风格时对属性特征的 处理方法均受cGAN启发。

六、代码实现,生成指定手写数字import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimimport numpy as npimport matplotlib.pyplot as pltimport torchvisionfrom torchvision import transformsfrom torch.utils import dataimport osimport globfrom PIL import Image# 独热编码# 输入x代表默认的torchvision返回的类比值,class_count类别值为10def one_hot(x, class_count=10): return torch.eye(class_count)[x, :] # 切片选取,第一维选取第x个,第二维全要transform =transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])dataset = torchvision.datasets.MNIST('data', train=True, transform=transform, target_transform=one_hot, download=False)dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)# 定义生成器class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.linear1 = nn.Linear(10, 128 * 7 * 7) self.bn1 = nn.BatchNorm1d(128 * 7 * 7) self.linear2 = nn.Linear(100, 128 * 7 * 7) self.bn2 = nn.BatchNorm1d(128 * 7 * 7) self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), padding=1) self.bn3 = nn.BatchNorm2d(128) self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=2, padding=1) self.bn4 = nn.BatchNorm2d(64) self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=2, padding=1) def forward(self, x1, x2): x1 = F.relu(self.linear1(x1)) x1 = self.bn1(x1) x1 = x1.view(-1, 128, 7, 7) x2 = F.relu(self.linear2(x2)) x2 = self.bn2(x2) x2 = x2.view(-1, 128, 7, 7) x = torch.cat([x1, x2], axis=1) x = F.relu(self.deconv1(x)) x = self.bn3(x) x = F.relu(self.deconv2(x)) x = self.bn4(x) x = torch.tanh(self.deconv3(x)) return x# 定义判别器# input:1,28,28的图片以及长度为10的conditionclass Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.linear = nn.Linear(10, 1*28*28) self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2) self.bn = nn.BatchNorm2d(128) self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值 def forward(self, x1, x2): x1 =F.leaky_relu(self.linear(x1)) x1 = x1.view(-1, 1, 28, 28) x = torch.cat([x1, x2], axis=1) x = F.dropout2d(F.leaky_relu(self.conv1(x))) x = F.dropout2d(F.leaky_relu(self.conv2(x))) x = self.bn(x) x = x.view(-1, 128*6*6) x = torch.sigmoid(self.fc(x)) return x# 初始化模型device = 'cuda' if torch.cuda.is_available() else 'cpu'gen = Generator().to(device)dis = Discriminator().to(device)# 损失计算函数loss_function = torch.nn.BCELoss()# 定义优化器d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)# 定义可视化函数def generate_and_save_images(model, epoch, label_input, noise_input): predictions = np.squeeze(model(label_input, noise_input).cpu().numpy()) fig = plt.figure(figsize=(4, 4)) for i in range(predictions.shape[0]): plt.subplot(4, 4, i + 1) plt.imshow((predictions[i] + 1) / 2, cmap='gray') plt.axis("off") plt.savefig('D:/practice/CGAN/img/image_at_epoch_{:04d}.png'.format(epoch)) plt.show()noise_seed = torch.randn(16, 100, device=device)label_seed = torch.randint(0, 10, size=(16,))label_seed_onehot = one_hot(label_seed).to(device)print(label_seed)# print(label_seed_onehot)# 开始训练D_loss = []G_loss = []# 训练循环for epoch in range(150): d_epoch_loss = 0 g_epoch_loss = 0 count = len(dataloader.dataset) # 对全部的数据集做一次迭代 for step, (img, label) in enumerate(dataloader): img = img.to(device) label = label.to(device) size = img.shape[0] random_noise = torch.randn(size, 100, device=device) d_optim.zero_grad() real_output = dis(label, img) d_real_loss = loss_function(real_output, torch.ones_like(real_output, device=device) ) d_real_loss.backward() #求解梯度 # 得到判别器在生成图像上的损失 gen_img = gen(label,random_noise) fake_output = dis(label, gen_img.detach()) # 判别器输入生成的图片,f_o是对生成图片的预测结果 d_fake_loss = loss_function(fake_output, torch.zeros_like(fake_output, device=device)) d_fake_loss.backward() d_loss = d_real_loss + d_fake_loss d_optim.step() # 优化 # 得到生成器的损失 g_optim.zero_grad() fake_output = dis(label, gen_img) g_loss = loss_function(fake_output, torch.ones_like(fake_output, device=device)) g_loss.backward() g_optim.step() with torch.no_grad(): d_epoch_loss += d_loss.item() g_epoch_loss += g_loss.item() with torch.no_grad(): d_epoch_loss /= count g_epoch_loss /= count D_loss.append(d_epoch_loss) G_loss.append(g_epoch_loss) if epoch % 10 == 0: print('Epoch:', epoch) generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)plt.plot(D_loss, label='D_loss')plt.plot(G_loss, label='G_loss')plt.legend()plt.show()

具体实战代码解读,参考:GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字

本文链接地址:https://www.jiuchutong.com/zhishi/300285.html 转载请保留说明!

上一篇:R-CNN史上最全讲解(rcnn系列详解)

下一篇:结构重参数化(Structural Re-Parameters)PipLine(结构重参数化2d pose)

  • 淘宝退款到花呗怎么查询到账(淘宝退款到花呗,花呗已经付清,怎么办)

    淘宝退款到花呗怎么查询到账(淘宝退款到花呗,花呗已经付清,怎么办)

  • 荣耀30pro如何隐藏软件(荣耀30Pro如何隐藏软件)

    荣耀30pro如何隐藏软件(荣耀30Pro如何隐藏软件)

  • 第一弹闪退(第一弹用不了了)

    第一弹闪退(第一弹用不了了)

  • 华为wifi信号带感叹号(华为wifi信号弱怎么办)

    华为wifi信号带感叹号(华为wifi信号弱怎么办)

  • 苹果融合硬盘什么意思(苹果的融合硬盘和固态区别)

    苹果融合硬盘什么意思(苹果的融合硬盘和固态区别)

  • 集线器和交换机有什么区别(集线器和交换机的作用)

    集线器和交换机有什么区别(集线器和交换机的作用)

  • 喷墨打印机打出来一条一条的(喷墨打印机打出来的字断断续续怎么解决)

    喷墨打印机打出来一条一条的(喷墨打印机打出来的字断断续续怎么解决)

  • 快手提现失败是怎么回事(快手提现失败是咋回事)

    快手提现失败是怎么回事(快手提现失败是咋回事)

  • 微信版本过低不能注册(微信版本过低不能登录怎么办)

    微信版本过低不能注册(微信版本过低不能登录怎么办)

  • mac换行快捷键(mac换行符)

    mac换行快捷键(mac换行符)

  • 华为p40双卡怎么放(华为p40双卡怎么切换打电话)

    华为p40双卡怎么放(华为p40双卡怎么切换打电话)

  • 微信怎么突然没有一个好友了(微信怎么突然没有提取文字功能)

    微信怎么突然没有一个好友了(微信怎么突然没有提取文字功能)

  • 微信能查出谁分享的名片吗(微信可以查出是谁吗)

    微信能查出谁分享的名片吗(微信可以查出是谁吗)

  • 根据ip地址能查上网记录吗(根据ip地址能查询哪些信息)

    根据ip地址能查上网记录吗(根据ip地址能查询哪些信息)

  • 苹果手机主题怎么设置成自己喜欢的(苹果手机主题怎么设置)

    苹果手机主题怎么设置成自己喜欢的(苹果手机主题怎么设置)

  • iphone7一直显示白苹果开不了机关不了机(iphone7一直显示无服务)

    iphone7一直显示白苹果开不了机关不了机(iphone7一直显示无服务)

  • 红米note8怎么打开红包助手(红米note8怎么打印)

    红米note8怎么打开红包助手(红米note8怎么打印)

  • Reno Ace怎么打开省电模式(reno ace rom)

    Reno Ace怎么打开省电模式(reno ace rom)

  • 快手小店上传商品要多久(快手小店上传商品提示需要上传产品资质)

    快手小店上传商品要多久(快手小店上传商品提示需要上传产品资质)

  • iphne怎么设置来电铃声(苹果手机如何设置来电模式)

    iphne怎么设置来电铃声(苹果手机如何设置来电模式)

  • win10诊断策略服务未运行(win10诊断策略服务未运行无法上网错误5)

    win10诊断策略服务未运行(win10诊断策略服务未运行无法上网错误5)

  • pr如何加音乐(pr如何加音乐视频)

    pr如何加音乐(pr如何加音乐视频)

  • iphone黑名单在哪里查找(苹果手机的黑名单在哪里看到)

    iphone黑名单在哪里查找(苹果手机的黑名单在哪里看到)

  • Linux下使用Speedtest测试网速的方法(linux 速度)

    Linux下使用Speedtest测试网速的方法(linux 速度)

  • vue-cli创建vue项目详细步骤(vue-cli4创建项目)

    vue-cli创建vue项目详细步骤(vue-cli4创建项目)

  • 印花税计算是含增值税吗
  • 工会筹备金的计税依据是应发工资还是实发工资
  • 公司注册资本认缴
  • 分期收款销售的基本业务处理
  • 企业弥补以前年度亏损顺序
  • 个体户的附加税表怎么填写
  • 实收资本在利润表中怎么体现出来
  • 非生产性费用不应计入产品成本
  • 未达起征点的税金如何做账
  • 税控技术维护费每年都能抵扣吗
  • 自产产品用于职工福利确认收入吗
  • 出口零退税率是什么意思
  • 运输费计入采购成本的分录
  • 库存材料盘亏会计分录
  • 个人所得税的税收优惠项目有哪些
  • 给职工发放的福利费,要从应付职工薪酬科目吗
  • 职工的大病医保怎么报销
  • 房屋租赁发票能抵扣几个点
  • 怎么分清楚待认识的人
  • 小规模纳税人一个季度多少免税
  • 购买税友系统可以抵扣吗
  • 无法读取金税盘时间版本怎么解决
  • 发生广告费用会计分录
  • 小规模纳税人优惠政策类型怎么选
  • 消费税为什么不计入长投成本
  • 工伤后辞职了还可以报工伤
  • 收到人民政府寄来的ems
  • 对公提回贷算收入吗
  • 旅游门票报销怎么算
  • 公司体检如何入账
  • 增值税小规模纳税人免征增值税政策
  • 委托贷款是流动资金贷款吗
  • 做研发费用需要什么条件
  • 产品入库的业务流程
  • 退回多扣的社保费给员工,怎样做会计分录?
  • 收取水电费如何开票
  • 0xc000007b应用程序无法正常启动win11
  • 网络和共享中心在哪里打开
  • 个人股权分红如何缴税
  • 冈山平原
  • 实收资本与注册资本之间的关系
  • 如何修改家里的wifi密码
  • 销售不动产税目计缴增值税有哪些
  • 福利费需要缴税吗
  • php调试函数
  • 驾校属于什么行业分类类别
  • vue项目启动过程
  • 物料最低库存
  • 融资租入固定资产的入账价值
  • 银行转账支付中是什么状态
  • 2022年山东省固定资产投资额
  • 下列行为免征增值税的有
  • 打印银行电子回单有断号
  • 长期股权投资的成本法和权益法区别
  • 短期借款如何记账
  • 可以抵扣进项税的项目包括
  • 凭证过账的步骤
  • 塑料制品厂设计
  • sql游标用法
  • xp系统如何打印文件
  • 苹果mac没有声音怎么办
  • macbook怎么开hdr
  • xp系统的存储在哪里
  • windows 8.1 with update
  • win10系统关闭安全中心
  • win10系统无法打开百度网盘
  • 小地图的主要作用是观察队友的大概位置
  • android开发手册
  • opengl和directX区别
  • linux进程管理命令使用
  • xbox无法连接无线网络
  • JavaScript中的复杂数据类型又称为
  • javascript教程
  • pythonnumpy报错
  • 国家税务统一代码查询
  • 定额发票网上查询
  • 黑龙江国税局官网
  • 药店迁址流程2019
  • 苏州相城离苏州市区有多远
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设