位置: IT常识 - 正文

Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成(diffusion扩散模型训练时间)

编辑:rootadmin
Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成 Diffusion扩散模型学习1——Pytorch搭建DDPM利用深度卷积神经网络实现图片生成学习前言源码下载地址网络构建一、什么是Diffusion1、加噪过程2、去噪过程二、DDPM网络的构建(Unet网络的构建)三、Diffusion的训练思路利用DDPM生成图片一、数据集的准备二、数据集的处理三、模型训练学习前言

推荐整理分享Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成(diffusion扩散模型训练时间),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:diffusion扩散模型是什么,diffusion扩散模型是哪一年的,diffusion扩散模型 nlp,diffusion扩散模型是哪一年的,diffusion扩散模型应用,diffusion扩散模型是哪一年的,diffusion扩散模型训练时间,diffusion扩散模型应用,内容如对您有帮助,希望把文章链接给更多的朋友!

我又死了我又死了我又死了!

源码下载地址

https://github.com/bubbliiiing/ddpm-pytorch

喜欢的可以点个star噢。

网络构建一、什么是Diffusion

如上图所示。DDPM模型主要分为两个过程: 1、Forward加噪过程(从右往左),数据集的真实图片中逐步加入高斯噪声,最终变成一个杂乱无章的高斯噪声,这个过程一般发生在训练的时候。加噪过程满足一定的数学规律。 2、Reverse去噪过程(从左往右),指对加了噪声的图片逐步去噪,从而还原出真实图片,这个过程一般发生在预测生成的时候。尽管在这里说的是加了噪声的图片,但实际去预测生成的时候,是随机生成一个高斯噪声来去噪。去噪的时候不断根据XtX_tXt​的图片生成Xt−1X_{t-1}Xt−1​的噪声,从而实现图片的还原。

1、加噪过程

Forward加噪过程主要符合如下的公式: xt=αtxt−1+1−αtz1x_t=\sqrt{\alpha_t} x_{t-1}+\sqrt{1-\alpha_t} z_{1}xt​=αt​​xt−1​+1−αt​​z1​ 其中αt\sqrt{\alpha_t}αt​​是预先设定好的超参数,被称为Noise schedule,通常是小于1的值,在论文中αt\alpha_tαt​的值从0.9999到0.998。ϵt−1∼N(,1)\epsilon_{t-1} \sim N(0, 1)ϵt−1​∼N(0,1)是高斯噪声。由公式(1)迭代推导。

xt=at(at−1xt−2+1−αt−1z2)+1−αtz1=atat−1xt−2+(at(1−αt−1)z2+1−αtz1)x_t=\sqrt{a_t}\left(\sqrt{a_{t-1}} x_{t-2}+\sqrt{1-\alpha_{t-1}} z_2\right)+\sqrt{1-\alpha_t} z_1=\sqrt{a_t a_{t-1}} x_{t-2}+\left(\sqrt{a_t\left(1-\alpha_{t-1}\right)} z_2+\sqrt{1-\alpha_t} z_1\right)xt​=at​​(at−1​​xt−2​+1−αt−1​​z2​)+1−αt​​z1​=at​at−1​​xt−2​+(at​(1−αt−1​)​z2​+1−αt​​z1​)

其中每次加入的噪声都服从高斯分布 z1,z2,…∼N(,1)z_1, z_2, \ldots \sim \mathcal{N}(0, 1)z1​,z2​,…∼N(0,1),两个高斯分布的相加高斯分布满足公式:N(,σ12)+N(,σ22)∼N(,(σ12+σ22))\mathcal{N}\left(0, \sigma_1^2 \right)+\mathcal{N}\left(0, \sigma_2^2 \right) \sim \mathcal{N}\left(0,\left(\sigma_1^2+\sigma_2^2\right) \right)N(0,σ12​)+N(0,σ22​)∼N(0,(σ12​+σ22​)),因此,得到xtx_txt​的公式为: xt=atat−1xt−2+1−αtαt−1z2x_t = \sqrt{a_t a_{t-1}} x_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}} z_2xt​=at​at−1​​xt−2​+1−αt​αt−1​​z2​ 因此不断往里面套,就能发现规律了,其实就是累乘 可以直接得出xx_0x0​到xtx_txt​的公式: xt=αt‾x+1−αt‾ztx_t=\sqrt{\overline{\alpha_t}} x_0+\sqrt{1-\overline{\alpha_t}} z_txt​=αt​​​x0​+1−αt​​​zt​

其中αt‾=∏itαi\overline{\alpha_t}=\prod_i^t \alpha_iαt​​=∏it​αi​,这是随Noise schedule设定好的超参数,zt−1∼N(,1)z_{t-1} \sim N(0, 1)zt−1​∼N(0,1)也是一个高斯噪声。通过上述两个公式,我们可以不断的将图片进行破坏加噪。

2、去噪过程

反向过程就是通过估测噪声,多次迭代逐渐将被破坏的xtx_txt​恢复成xx_0x0​,在恢复时刻,我们已经知道的是xtx_txt​,这是图片在ttt时刻的噪声图。一下子从xtx_txt​恢复成xx_0x0​是不可能的,我们只能一步一步的往前推,首先从xtx_txt​恢复成xt−1x_{t-1}xt−1​。根据贝叶斯公式,已知xtx_txt​反推xt−1x_{t-1}xt−1​: q(xt−1∣xt,x)=q(xt∣xt−1,x)q(xt−1∣x)q(xt∣x)q\left(x_{t-1} \mid x_t, x_0\right)=q\left(x_t \mid x_{t-1}, x_0\right) \frac{q\left(x_{t-1} \mid x_0\right)}{q\left(x_t \mid x_0\right)}q(xt−1​∣xt​,x0​)=q(xt​∣xt−1​,x0​)q(xt​∣x0​)q(xt−1​∣x0​)​ 右边的三个东西都可以从x_0开始推得到: q(xt−1∣x)=aˉt−1x+1−aˉt−1z∼N(aˉt−1x,1−aˉt−1)q\left(x_{t-1} \mid x_0\right)=\sqrt{\bar{a}_{t-1}} x_0+\sqrt{1-\bar{a}_{t-1}} z \sim \mathcal{N}\left(\sqrt{\bar{a}_{t-1}} x_0, 1-\bar{a}_{t-1}\right)q(xt−1​∣x0​)=aˉt−1​​x0​+1−aˉt−1​​z∼N(aˉt−1​​x0​,1−aˉt−1​) q(xt∣x)=aˉtx+1−αˉtz∼N(aˉtx,1−αˉt)q\left(x_t \mid x_0\right) = \sqrt{\bar{a}_t} x_0+\sqrt{1-\bar{\alpha}_t} z \sim \mathcal{N}\left(\sqrt{\bar{a}_t} x_0 , 1-\bar{\alpha}_t\right)q(xt​∣x0​)=aˉt​​x0​+1−αˉt​​z∼N(aˉt​​x0​,1−αˉt​) q(xt∣xt−1,x)=atxt−1+1−αtz∼N(atxt−1,1−αt)q\left(x_t \mid x_{t-1}, x_0\right)=\sqrt{a_t} x_{t-1}+\sqrt{1-\alpha_t} z \sim \mathcal{N}\left(\sqrt{a_t} x_{t-1}, 1-\alpha_t\right) \\q(xt​∣xt−1​,x0​)=at​​xt−1​+1−αt​​z∼N(at​​xt−1​,1−αt​) 因此,由于右边三个东西均满足正态分布,q(xt−1∣xt,x)q\left(x_{t-1} \mid x_t, x_0\right)q(xt−1​∣xt​,x0​)满足分布如下: ∝exp⁡(−12((xt−αtxt−1)2βt+(xt−1−αˉt−1x)21−αˉt−1−(xt−αˉtx)21−αˉt))\propto \exp \left(-\frac{1}{2}\left(\frac{\left(x_t-\sqrt{\alpha_t} x_{t-1}\right)^2}{\beta_t}+\frac{\left(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}} x_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(x_t-\sqrt{\bar{\alpha}_t} x_0\right)^2}{1-\bar{\alpha}_t}\right)\right)∝exp(−21​(βt​(xt​−αt​​xt−1​)2​+1−αˉt−1​(xt−1​−αˉt−1​​x0​)2​−1−αˉt​(xt​−αˉt​​x0​)2​)) 把标准正态分布展开后,乘法就相当于加,除法就相当于减,把他们汇总 接下来继续化简,咱们现在要求的是上一时刻的分布 ∝exp⁡(−12((xt−αtxt−1)2βt+(xt−1−αˉt−1x)21−αˉt−1−(xt−αˉtx)21−αˉt))=exp⁡(−12(xt2−2αtxtxt−1+αtxt−12βt+xt−12−2αˉt−1xxt−1+αˉt−1x21−αˉt−1−(xt−αˉtx)21−αˉt))=exp⁡(−12((αtβt+11−αˉt−1)xt−12−(2αtβtxt+2αˉt−11−αˉt−1x)xt−1+C(xt,x)))\begin{aligned} & \propto \exp \left(-\frac{1}{2}\left(\frac{\left(x_t-\sqrt{\alpha_t} x_{t-1}\right)^2}{\beta_t}+\frac{\left(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}} x_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(x_t-\sqrt{\bar{\alpha}_t} x_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\frac{x_t^2-2 \sqrt{\alpha_t} x_t x_{t-1}+\alpha_t x_{t-1}^2}{\beta_t}+\frac{x_{t-1}^2-2 \sqrt{\bar{\alpha}_{t-1}} x_0 x_{t-1}+\bar{\alpha}_{t-1} x_0^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(x_t-\sqrt{\bar{\alpha}_t} x_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) x_{t-1}^2-\left(\frac{2 \sqrt{\alpha_t}}{\beta_t} x_t+\frac{2 \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} x_0\right) x_{t-1}+C\left(x_t, x_0\right)\right)\right) \end{aligned}​∝exp(−21​(βt​(xt​−αt​​xt−1​)2​+1−αˉt−1​(xt−1​−αˉt−1​​x0​)2​−1−αˉt​(xt​−αˉt​​x0​)2​))=exp(−21​(βt​xt2​−2αt​​xt​xt−1​+αt​xt−12​​+1−αˉt−1​xt−12​−2αˉt−1​​x0​xt−1​+αˉt−1​x02​​−1−αˉt​(xt​−αˉt​​x0​)2​))=exp(−21​((βt​αt​​+1−αˉt−1​1​)xt−12​−(βt​2αt​​​xt​+1−αˉt−1​2αˉt−1​​​x0​)xt−1​+C(xt​,x0​)))​ 正态分布满足公式,exp⁡(−(x−μ)22σ2)=exp⁡(−12(1σ2x2−2μσ2x+μ2σ2))\exp \left(-\frac{(x-\mu)^2}{2 \sigma^2}\right)=\exp \left(-\frac{1}{2}\left(\frac{1}{\sigma^2} x^2-\frac{2 \mu}{\sigma^2} x+\frac{\mu^2}{\sigma^2}\right)\right)exp(−2σ2(x−μ)2​)=exp(−21​(σ21​x2−σ22μ​x+σ2μ2​)),其中σ\sigmaσ就是方差,μ\muμ就是均值,配方后我们就可以获得均值和方差。

Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成(diffusion扩散模型训练时间)

此时的均值为:μ~t(xt,x)=αt(1−αˉt−1)1−αˉtxt+αˉt−1βt1−αˉtx\tilde{\mu}_t\left(x_t, x_0\right)=\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} x_t+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} x_0μ~​t​(xt​,x0​)=1−αˉt​αt​​(1−αˉt−1​)​xt​+1−αˉt​αˉt−1​​βt​​x0​。根据之前的公式,xt=αt‾x+1−αt‾ztx_t=\sqrt{\overline{\alpha_t}} x_0+\sqrt{1-\overline{\alpha_t}} z_txt​=αt​​​x0​+1−αt​​​zt​,我们可以使用xtx_txt​反向估计xx_0x0​得到xx_0x0​满足分布x=1αˉt(xt−1−αˉtzt)x_0=\frac{1}{\sqrt{\bar{\alpha}_t}}\left(\mathrm{x}_t-\sqrt{1-\bar{\alpha}_t} z_t\right)x0​=αˉt​​1​(xt​−1−αˉt​​zt​)。最终得到均值为μ~t=1at(xt−βt1−aˉtzt)\tilde{\mu}_t=\frac{1}{\sqrt{a_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{a}_t}} z_t\right)μ~​t​=at​​1​(xt​−1−aˉt​​βt​​zt​) ,ztz_tzt​代表t时刻的噪音是什么。由ztz_tzt​无法直接获得,网络便通过当前时刻的xtx_txt​经过神经网络计算ztz_tzt​。ϵθ(xt,t)\epsilon_\theta\left(x_t, t\right)ϵθ​(xt​,t)也就是上面提到的ztz_tzt​。ϵθ\epsilon_\thetaϵθ​代表神经网络。 xt−1=1αt(xt−1−αt1−αˉtϵθ(xt,t))+σtzx_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta\left(x_t, t\right)\right)+\sigma_t zxt−1​=αt​​1​(xt​−1−αˉt​​1−αt​​ϵθ​(xt​,t))+σt​z 由于加噪过程中的真实噪声ϵ\epsilonϵ在复原过程中是无法获得的,因此DDPM的关键就是训练一个由xtx_txt​和ttt估测橾声的模型 ϵθ(xt,t)\epsilon_\theta\left(x_t, t\right)ϵθ​(xt​,t),其中θ\thetaθ就是模型的训练参数,σt\sigma_tσt​ 也是一个高斯噪声 σt∼N(,1)\sigma_t \sim N(0,1)σt​∼N(0,1),用于表示估测与实际的差距。在DDPM中,使用U-Net作为估测噪声的模型。

本质上,我们就是训练这个Unet模型,该模型输入为xtx_txt​和ttt,输出为xtx_txt​时刻的高斯噪声。即利用xtx_txt​和ttt预测这一时刻的高斯噪声。这样就可以一步一步的再从噪声回到真实图像。

二、DDPM网络的构建(Unet网络的构建)

上图是典型的Unet模型结构,仅仅作为示意图,里面具体的数字同学们无需在意,和本文的学习无关。在本文中,Unet的输入和输出shape相同,通道均为3(一般为RGB三通道),宽高相同。

本质上,DDPM最重要的工作就是训练Unet模型,该模型输入为xtx_txt​和ttt,输出为xt−1x_{t-1}xt−1​时刻的高斯噪声。即利用xtx_txt​和ttt预测上一时刻的高斯噪声。这样就可以一步一步的再从噪声回到真实图像。

假设我们需要生成一个[64, 64, 3]的图像,在ttt时刻,我们有一个xtx_txt​噪声图,该噪声图的的shape也为[64, 64, 3],我们将它和ttt一起输入到Unet中。Unet的输出为xt−1x_{t-1}xt−1​时刻的[64, 64, 3]的噪声。

实现代码如下,代码中的特征提取模块为残差结构,方便优化:

import mathimport torchimport torch.nn as nnimport torch.nn.functional as Fdef get_norm(norm, num_channels, num_groups): if norm == "in": return nn.InstanceNorm2d(num_channels, affine=True) elif norm == "bn": return nn.BatchNorm2d(num_channels) elif norm == "gn": return nn.GroupNorm(num_groups, num_channels) elif norm is None: return nn.Identity() else: raise ValueError("unknown normalization type")#------------------------------------------## 计算时间步长的位置嵌入。# 一半为sin,一半为cos。#------------------------------------------#class PositionalEmbedding(nn.Module): def __init__(self, dim, scale=1.0): super().__init__() assert dim % 2 == 0 self.dim = dim self.scale = scale def forward(self, x): device = x.device half_dim = self.dim // 2 emb = math.log(10000) / half_dim emb = torch.exp(torch.arange(half_dim, device=device) * -emb) # x * self.scale和emb外积 emb = torch.outer(x * self.scale, emb) emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb#------------------------------------------## 下采样层,一个步长为2x2的卷积#------------------------------------------#class Downsample(nn.Module): def __init__(self, in_channels): super().__init__() self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1) def forward(self, x, time_emb, y): if x.shape[2] % 2 == 1: raise ValueError("downsampling tensor height should be even") if x.shape[3] % 2 == 1: raise ValueError("downsampling tensor width should be even") return self.downsample(x)#------------------------------------------## 上采样层,Upsample+卷积#------------------------------------------#class Upsample(nn.Module): def __init__(self, in_channels): super().__init__() self.upsample = nn.Sequential( nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(in_channels, in_channels, 3, padding=1), ) def forward(self, x, time_emb, y): return self.upsample(x)#------------------------------------------## 使用Self-Attention注意力机制# 做一个全局的Self-Attention#------------------------------------------#class AttentionBlock(nn.Module): def __init__(self, in_channels, norm="gn", num_groups=32): super().__init__() self.in_channels = in_channels self.norm = get_norm(norm, in_channels, num_groups) self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1) self.to_out = nn.Conv2d(in_channels, in_channels, 1) def forward(self, x): b, c, h, w = x.shape q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1) q = q.permute(0, 2, 3, 1).view(b, h * w, c) k = k.view(b, c, h * w) v = v.permute(0, 2, 3, 1).view(b, h * w, c) dot_products = torch.bmm(q, k) * (c ** (-0.5)) assert dot_products.shape == (b, h * w, h * w) attention = torch.softmax(dot_products, dim=-1) out = torch.bmm(attention, v) assert out.shape == (b, h * w, c) out = out.view(b, h, w, c).permute(0, 3, 1, 2) return self.to_out(out) + x#------------------------------------------## 用于特征提取的残差结构#------------------------------------------#class ResidualBlock(nn.Module): def __init__( self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=F.relu, norm="gn", num_groups=32, use_attention=False, ): super().__init__() self.activation = activation self.norm_1 = get_norm(norm, in_channels, num_groups) self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1) self.norm_2 = get_norm(norm, out_channels, num_groups) self.conv_2 = nn.Sequential( nn.Dropout(p=dropout), nn.Conv2d(out_channels, out_channels, 3, padding=1), ) self.time_bias = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None self.residual_connection = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() self.attention = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups) def forward(self, x, time_emb=None, y=None): out = self.activation(self.norm_1(x)) # 第一个卷积 out = self.conv_1(out) # 对时间time_emb做一个全连接,施加在通道上 if self.time_bias is not None: if time_emb is None: raise ValueError("time conditioning was specified but time_emb is not passed") out += self.time_bias(self.activation(time_emb))[:, :, None, None] # 对种类y_emb做一个全连接,施加在通道上 if self.class_bias is not None: if y is None: raise ValueError("class conditioning was specified but y is not passed") out += self.class_bias(y)[:, :, None, None] out = self.activation(self.norm_2(out)) # 第二个卷积+残差边 out = self.conv_2(out) + self.residual_connection(x) # 最后做个Attention out = self.attention(out) return out#------------------------------------------## Unet模型#------------------------------------------#class UNet(nn.Module): def __init__( self, img_channels, base_channels=128, channel_mults=(1, 2, 2, 2), num_res_blocks=2, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=F.silu, dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0, ): super().__init__() # 使用到的激活函数,一般为SILU self.activation = activation # 是否对输入进行padding self.initial_pad = initial_pad # 需要去区分的类别数 self.num_classes = num_classes # 对时间轴输入的全连接层 self.time_mlp = nn.Sequential( PositionalEmbedding(base_channels, time_emb_scale), nn.Linear(base_channels, time_emb_dim), nn.SiLU(), nn.Linear(time_emb_dim, time_emb_dim), ) if time_emb_dim is not None else None # 对输入图片的第一个卷积 self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1) # self.downs用于存储下采样用到的层,首先利用ResidualBlock提取特征 # 然后利用Downsample降低特征图的高宽 self.downs = nn.ModuleList() self.ups = nn.ModuleList() # channels指的是每一个模块处理后的通道数 # now_channels是一个中间变量,代表中间的通道数 channels = [base_channels] now_channels = base_channels for i, mult in enumerate(channel_mults): out_channels = base_channels * mult for _ in range(num_res_blocks): self.downs.append( ResidualBlock( now_channels, out_channels, dropout, time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions, ) ) now_channels = out_channels channels.append(now_channels) if i != len(channel_mults) - 1: self.downs.append(Downsample(now_channels)) channels.append(now_channels) # 可以看作是特征整合,中间的一个特征提取模块 self.mid = nn.ModuleList( [ ResidualBlock( now_channels, now_channels, dropout, time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups, use_attention=True, ), ResidualBlock( now_channels, now_channels, dropout, time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups, use_attention=False, ), ] ) # 进行上采样,进行特征融合 for i, mult in reversed(list(enumerate(channel_mults))): out_channels = base_channels * mult for _ in range(num_res_blocks + 1): self.ups.append(ResidualBlock( channels.pop() + now_channels, out_channels, dropout, time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions, )) now_channels = out_channels if i != 0: self.ups.append(Upsample(now_channels)) assert len(channels) == 0 self.out_norm = get_norm(norm, base_channels, num_groups) self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1) def forward(self, x, time=None, y=None): # 是否对输入进行padding ip = self.initial_pad if ip != 0: x = F.pad(x, (ip,) * 4) # 对时间轴输入的全连接层 if self.time_mlp is not None: if time is None: raise ValueError("time conditioning was specified but tim is not passed") time_emb = self.time_mlp(time) else: time_emb = None if self.num_classes is not None and y is None: raise ValueError("class conditioning was specified but y is not passed") # 对输入图片的第一个卷积 x = self.init_conv(x) # skips用于存放下采样的中间层 skips = [x] for layer in self.downs: x = layer(x, time_emb, y) skips.append(x) # 特征整合与提取 for layer in self.mid: x = layer(x, time_emb, y) # 上采样并进行特征融合 for layer in self.ups: if isinstance(layer, ResidualBlock): x = torch.cat([x, skips.pop()], dim=1) x = layer(x, time_emb, y) # 上采样并进行特征融合 x = self.activation(self.out_norm(x)) x = self.out_conv(x) if self.initial_pad != 0: return x[:, :, ip:-ip, ip:-ip] else: return x三、Diffusion的训练思路

Diffusion的训练思路比较简单,首先随机给每个batch里每张图片都生成一个t,代表我选择这个batch里面第t个时刻的噪声进行拟合。代码如下:

t = torch.randint(0, self.num_timesteps, (b,), device=device)

生成batch_size个噪声,计算施加这个噪声后模型在t个时刻的噪声图片是怎么样的,如下所示:

def perturb_x(self, x, t, noise): return ( extract(self.sqrt_alphas_cumprod, t, x.shape) * x + extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise ) def get_losses(self, x, t, y): # x, noise [batch_size, 3, 64, 64] noise = torch.randn_like(x) perturbed_x = self.perturb_x(x, t, noise)

之后利用这个噪声图片、t和网络模型计算预测噪声,利用预测噪声和实际噪声进行拟合。

def get_losses(self, x, t, y): # x, noise [batch_size, 3, 64, 64] noise = torch.randn_like(x) perturbed_x = self.perturb_x(x, t, noise) estimated_noise = self.model(perturbed_x, t, y) if self.loss_type == "l1": loss = F.l1_loss(estimated_noise, noise) elif self.loss_type == "l2": loss = F.mse_loss(estimated_noise, noise) return loss利用DDPM生成图片

DDPM的库整体结构如下:

一、数据集的准备

在训练前需要准备好数据集,数据集保存在datasets文件夹里面。

二、数据集的处理

打开txt_annotation.py,默认指向根目录下的datasets。运行txt_annotation.py。 此时生成根目录下面的train_lines.txt。

三、模型训练

在完成数据集处理后,运行train.py即可开始训练。 训练过程中,可在results文件夹内查看训练效果:

本文链接地址:https://www.jiuchutong.com/zhishi/299194.html 转载请保留说明!

上一篇:图表库-Echarts(图表库网站)

下一篇:新一代 L1 公链Aptos:安全、可扩展和可升级的Web3基础设施 |Tokenview(公链dapp)

  • 礼仪在邮件群发中的有多重要(礼仪在邮件群发中的作用)

    礼仪在邮件群发中的有多重要(礼仪在邮件群发中的作用)

  • xml文件转化为excel(xml文件转化为excel数据变了)

    xml文件转化为excel(xml文件转化为excel数据变了)

  • oppo reno 4se支持NFC功能吗(oppo reno4支持nfc吗?)

    oppo reno 4se支持NFC功能吗(oppo reno4支持nfc吗?)

  • 抖音怎么下载音乐到本地(抖音怎么下载音频)

    抖音怎么下载音乐到本地(抖音怎么下载音频)

  • 3t硬盘实际容量是多少(3t硬盘容量为什么只有760g)

    3t硬盘实际容量是多少(3t硬盘容量为什么只有760g)

  • Word页眉怎么设置横线(word页眉怎么设置页码连续)

    Word页眉怎么设置横线(word页眉怎么设置页码连续)

  • 苹果6长宽高是多少厘米(苹果6手机长度和宽度各是多少?)

    苹果6长宽高是多少厘米(苹果6手机长度和宽度各是多少?)

  • 小爱音箱一直插电会烧坏吗(小爱音箱一直插电发热)

    小爱音箱一直插电会烧坏吗(小爱音箱一直插电发热)

  • 华为m6支持多屏协同吗(华为m6双屏怎么用)

    华为m6支持多屏协同吗(华为m6双屏怎么用)

  • 手机摔地上没有坏会不会有影响(手机摔地上没有声音怎么回事)

    手机摔地上没有坏会不会有影响(手机摔地上没有声音怎么回事)

  • 竖屏照片怎么变横屏(竖屏照片怎么变横屏铺满)

    竖屏照片怎么变横屏(竖屏照片怎么变横屏铺满)

  • mate30pro怎么调音量(华为mate30prozm怎么调音量)

    mate30pro怎么调音量(华为mate30prozm怎么调音量)

  • 为什么b站下载的视频和音频是分开的(为什么b站下载的视频打不开)

    为什么b站下载的视频和音频是分开的(为什么b站下载的视频打不开)

  • 手机还原设置是什么意思(手机还原设置是否包括更新的天气预报软件)

    手机还原设置是什么意思(手机还原设置是否包括更新的天气预报软件)

  • 华为mate30pro有没有来电闪光灯(华为mate30pro有没有无线充电功能)

    华为mate30pro有没有来电闪光灯(华为mate30pro有没有无线充电功能)

  • ps怎么抠字不要背景

    ps怎么抠字不要背景

  • word文档英文字体更改(word文档英文字母下面有红线)

    word文档英文字体更改(word文档英文字母下面有红线)

  • 手机wps表格怎么换行(手机wps表格怎么求和)

    手机wps表格怎么换行(手机wps表格怎么求和)

  • 手机卡显示hd怎样设置(手机卡显示hd怎么取消)

    手机卡显示hd怎样设置(手机卡显示hd怎么取消)

  • 闲鱼怎么设置登录密码(闲鱼怎么设置登录隐藏)

    闲鱼怎么设置登录密码(闲鱼怎么设置登录隐藏)

  • qq发不出图片什么原因(qq发不出图片什么原因iOS)

    qq发不出图片什么原因(qq发不出图片什么原因iOS)

  • ca数字证书是什么东西(ca数字证书是什么时候取消的)

    ca数字证书是什么东西(ca数字证书是什么时候取消的)

  • 手机qq正在下载怎么取消(手机qq正在下载安装)

    手机qq正在下载怎么取消(手机qq正在下载安装)

  • vivoy93什么时候上市的(vivoy93什么时候生产的)

    vivoy93什么时候上市的(vivoy93什么时候生产的)

  • 小米8如何分屏(小米如何分屏操作步骤)

    小米8如何分屏(小米如何分屏操作步骤)

  • 营业收入与利润变化图
  • 工会经费如何申报?
  • 预付加油卡发票可以报销吗
  • 承包费收入如何入账
  • 税务局代扣代缴税费
  • 企业所得税季度预缴可以弥补以前年度亏损吗
  • 小规模纳税人帮别人报关
  • 免税企业开了含税发票
  • 增值税2017年起征点
  • 购买金税盘需要法人去税务局进行信息采集吗
  • 关于不动产进项税额分期抵扣的新政策,以下不属于
  • 留存收益账务处理视频
  • 企业的不征税收入用于支出所形成的资产,其计算的折旧
  • 用钱买的代金券怎么使用
  • 抵税的税额怎么计算
  • 兼职劳务费个税怎么算
  • 赔偿给客户的钱抵扣货款的会计分录
  • 应收账款坏账准备是信用减值损失还是资产
  • 全免增值税企业有哪些
  • 其他应收账款怎么算
  • 股东大会的召集有权
  • 冲未开票收入怎么做分录
  • 员工意外伤害保险怎么买
  • linux系统中配置网卡ip地址的命令为
  • win10 kb5001567
  • win10待机界面进不去系统怎么办
  • 非正常损失运输费进项税额如何转出
  • vue做移动端适配最佳解决方案,亲测有效
  • 购买货物现金付讫的会计分录
  • opera software
  • php 设计模式
  • 补缴增值税和滞纳税区别
  • 权限控制的原则是什么
  • 预算会计与财务会计适度分离
  • Vision Transformer 模型详解
  • 最新前端面试题
  • php上传文件限制大小
  • 委托代销商品委托方和受托方会计分录
  • 织梦网站老是被挂马
  • 缴纳税款滞纳金怎么算
  • 印花税按什么的比例缴纳
  • 对公账户分类及区别
  • 应收票据贴现的实收金额一定小于票据面值
  • 资金账簿印花税按年还是按次
  • 车辆购置税可以网上缴纳吗
  • sql server数据查询语句
  • 企业所得税多预缴了怎么办
  • 公司首次申报个人所得税
  • 企业所得税季度申报数据怎么来
  • 研发费用明细科目怎么填
  • 按信用风险特征组合
  • 搬迁补偿款的会计分录
  • 长期应付款的会计编号
  • 填制费用报销单怎么填写
  • 土地使用权如何计入房产原值交房产税时间
  • 工资代扣工会会费协议
  • 办理分公司的流程样本
  • 股权变更需要交哪些税
  • mysql怎么复制粘贴语句
  • 终端运行mysql
  • docker mysql 数据
  • mysql5.5.62安装配置教程
  • win sth
  • 系统 启动速度慢怎么办
  • supporter5.exe - supporter5是什么进程
  • w7打穿越火线
  • xp如何升级到sp3
  • WIN10更新WIN11卡在63%
  • win8怎么创建宽带连接
  • win7怎连蓝牙
  • node.js gui
  • jquery validate表单内容怎么添加边框
  • angular 图片懒加载
  • python抢红包
  • unity灯光闪烁效果
  • javascript零基础入门
  • python生产
  • 税务局分类分级
  • 叶青和奚卫华
  • 房地产增值税怎么算举例说明
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设