位置: IT常识 - 正文

pytorch 自编码器实现图像的降噪(pytorch自动编码器)

编辑:rootadmin
pytorch 自编码器实现图像的降噪 自编码器

推荐整理分享pytorch 自编码器实现图像的降噪(pytorch自动编码器),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:python编码工具,python自编码器,python自编码器,pytorch代码,pytorch自动编码器,pytorch 源码编译,pytorch onehot编码,pytorch 自编码器 异常检测,内容如对您有帮助,希望把文章链接给更多的朋友!

自动编码器是一种无监督的深度学习算法,它学习输入数据的编码表示,然后重新构造与输出相同的输入。它由编码器和解码器两个网络组成。编码器将高维输入压缩成低维潜在(也称为潜在代码或编码空间) ,以从中提取最相关的信息,而解码器则解压缩编码数据并重新创建原始输入。

自编码器的输入和输出应该尽可能的相似。

通过输入含有噪声的图像,编码器在编码的过程中会存在信息丢失,将输入和输出最相似的特征保留下来,通过解码器得到最后的输出。在这个转换的过程中实现了图像的去噪。

自编码器主要的用途其实是用于降维,将高维的数据编码为一组向量,解码器通过解码得到输出。

数据集导入可视化import torchvisionimport matplotlib.pyplot as pltfrom torch.utils.data import DataLoaderimport numpy as npimport randomimport PIL.Image as Imageimport torchvision.transforms as transformsclass AddPepperNoise(object): """增加椒盐噪声 Args: snr (float): Signal Noise Rate p (float): 概率值,依概率执行该操作 """ def __init__(self, snr, p=0.9): assert isinstance(snr, float) and (isinstance(p, float)) self.snr = snr self.p = p def __call__(self, img): """ Args: img (PIL Image): PIL Image Returns: PIL Image: PIL image. """ if random.uniform(0, 1) < self.p: img_ = np.array(img).copy() h, w = img_.shape signal_pct = self.snr noise_pct = (1 - self.snr) mask = np.random.choice((0, 1, 2), size=(h, w), p=[signal_pct, noise_pct/2., noise_pct/2.]) img_[mask == 1] = 255 # 盐噪声 img_[mask == 2] = 0 # 椒噪声 return Image.fromarray(img_.astype('uint8')) else: return imgclass Gaussian_noise(object): """增加高斯噪声 此函数用将产生的高斯噪声加到图片上 传入: img : 原图 mean : 均值 sigma : 标准差 返回: gaussian_out : 噪声处理后的图片 """ def __init__(self, mean, sigma): self.mean = mean self.sigma = sigma def __call__(self, img): """ Args: img (PIL Image): PIL Image Returns: PIL Image: PIL image. """ # 将图片灰度标准化 img_ = np.array(img).copy() img_ = img_ / 255.0 # 产生高斯 noise noise = np.random.normal(self.mean, self.sigma, img_.shape) # 将噪声和图片叠加 gaussian_out = img_ + noise # 将超过 1 的置 1,低于 0 的置 0 gaussian_out = np.clip(gaussian_out, 0, 1) # 将图片灰度范围的恢复为 0-255 gaussian_out = np.uint8(gaussian_out*255) # 将噪声范围搞为 0-255 # noise = np.uint8(noise*255) return Image.fromarray(gaussian_out)train_datasets = torchvision.datasets.MNIST('./', train=True, download=True)test_datasets = torchvision.datasets.MNIST('./', train=False, download=True)print('训练集的数量', len(train_datasets))print('测试集的数量', len(test_datasets))train_loader = DataLoader(train_datasets, batch_size=100, shuffle=True)test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)transform=transforms.Compose([ transforms.ToPILImage(), Gaussian_noise(0,0.1), AddPepperNoise(0.9) # transforms.ToTensor()])print('训练集可视化')fig = plt.figure()for i in range(12): plt.subplot(3, 4, i + 1) img = train_datasets.train_data[i] label = train_datasets.train_labels[i] # noise = np.random.normal(0.1, 0.1, img.shape) # img=transform(img) plt.imshow(img, cmap='gray') plt.title(label) plt.xticks([]) plt.yticks([])plt.show()

噪声图像 

 原始图像

模型的搭建import torchfrom torch import nnclass AE(nn.Module): def __init__(self): super(AE, self).__init__() # [b, 784] => [b, 20] self.encoder = nn.Sequential( nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 64), nn.ReLU(), nn.Linear(64, 20), nn.ReLU() ) # [b, 20] => [b, 784] self.decoder = nn.Sequential( nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 784), nn.Sigmoid() ) def forward(self, x): """ :param x: [b, 1, 28, 28] :return: """ batchsz = x.size(0) # flatten(打平) x = x.view(batchsz, 784) # encoder x = self.encoder(x) # decoder x = self.decoder(x) # reshape x = x.view(batchsz, 1, 28, 28) return xif __name__=='__main__': model=AE() input=torch.randn(1,28,28) input=input.view(1,-1) print('输入的维度',input.shape) encoder_out=model.encoder(input) print('编码器的输出',encoder_out.shape) out=model.decoder(encoder_out) print('解码器的输出',out.shape)

 

模型的训练

导入训练集训练的时候一定要将使用transforms将所有图像转换为tensor格式,这里的方法不同于tensorflow导入MNIST方法,如果不加transforms则图像的格式为列表类型,下面在训练的时候会报错。

pytorch 自编码器实现图像的降噪(pytorch自动编码器)

在训练过程中添加噪声。分别添加了高斯噪声和椒盐噪声

import torchvisionfrom torch.utils.data import DataLoaderimport numpy as npimport random,osimport PIL.Image as Imageimport torchvision.transforms as transformsfrom torch import nn,optimimport torchfrom models import AEfrom tqdm import tqdmtrain_datasets = torchvision.datasets.MNIST('./', train=True, download=True,transform=transforms.ToTensor())test_datasets = torchvision.datasets.MNIST('./', train=False, download=True,transform=transforms.ToTensor())print('训练集的数量', len(train_datasets))print('测试集的数量', len(test_datasets))train_loader = DataLoader(train_datasets, batch_size=100, shuffle=True)test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)#模型,优化器,损失函数device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model=AE().to(device)criteon = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr=1e-3)##导入预训练模型if os.path.exists('./model.pth') : # 如果存在已保存的权重,则加载 checkpoint = torch.load('model.pth',map_location=lambda storage,loc:storage) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) initepoch = checkpoint['epoch'] loss = checkpoint['loss']else: initepoch=0#开始训练for epoch in range(initepoch, 50): with tqdm(total=(len(train_datasets)-len(train_datasets)), ncols=80) as t: t.set_description('epoch: {}/{}'.format(epoch, 50)) running_loss = 0.0 for i, data in enumerate(train_loader, 0): # get the inputs true_input, _ = data #生成均值为0,方差为0.1的高斯分布 gaussian_noise=torch.normal(mean=0,std=0.1,size=true_input.shape) image_noise=true_input+gaussian_noise noise_tensor = torch.rand(size=true_input.shape) #添加椒盐噪声 image_noise[noise_tensor<0.1]=0 #椒噪声 image_noise[noise_tensor > (1-0.1)] = 1 #盐噪声 #限制像素的范围在0-1之间 image_noise=torch.clamp(image_noise,min=0,max=1) optimizer.zero_grad() outputs = model(image_noise) loss = criteon(outputs, true_input) loss.backward() optimizer.step() running_loss += loss.item() t.set_postfix(trainloss='{:.6f}'.format(running_loss/len(train_loader))) t.update(len(true_input)) torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': running_loss/len(train_loader) }, 'model.pth')模型的测试

 在导入模型的时候经常发生上面的错误。模型在导入参数的时候不需要赋值操作。如果保存的方法是torch.load(model,'model.pth'),也就是直接保存模型的所有(包括模型的结构),在导入模型参数的时候可以使用model=torch.load('./model.pth')

import numpy as npimport torchvisionimport matplotlib.pyplot as pltfrom torch.utils.data import DataLoaderimport torchimport torchvision.transforms as transformsfrom models import AEfrom data import AddPepperNoise,Gaussian_noisetest_datasets = torchvision.datasets.MNIST('./', train=False, download=True)print('测试集的数量', len(test_datasets))test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)transform=transforms.Compose([ transforms.ToPILImage(), Gaussian_noise(0,0.2), AddPepperNoise(0.9), transforms.ToTensor()])model=AE()hh=torch.load('./model.pth',map_location=lambda storage,loc:storage)model.load_state_dict(hh['model_state_dict'])#错误写法# model=model.load_state_dict(hh['model_state_dict'])fig = plt.figure()for i in range(12): plt.subplot(3, 4, i + 1) img = test_datasets.train_data[i] label = test_datasets.test_labels[i] img_noise=transform(img) out=model(img_noise) out=out.squeeze() out=transforms.ToPILImage()(out) #原始图像,噪声图像,去噪图像 plt.imshow(np.hstack((np.array(img),np.array(transforms.ToPILImage()(img_noise)),np.array(out))), cmap='gray') plt.title(label) plt.xticks([]) plt.yticks([])plt.show()

 生成随机数看看解码器能解码出什么

生成标准正太分布import matplotlib.pyplot as pltimport torchimport torchvision.transforms as transformsfrom models import AEmodel=AE()model.load_state_dict(torch.load('./model.pth',map_location=lambda storage,loc:storage)['model_state_dict'])fig = plt.figure()for i in range(12): plt.subplot(3, 4, i + 1) input=torch.randn(1,20) out=model.decoder(input) out=out.view(28,28) out=transforms.ToPILImage()(out) plt.imshow(out, cmap='gray') plt.xticks([]) plt.yticks([])plt.show()

 生成0-1之间的均匀分布import matplotlib.pyplot as pltimport torchimport torchvision.transforms as transformsfrom models import AEmodel=AE()model.load_state_dict(torch.load('./model.pth',map_location=lambda storage,loc:storage)['model_state_dict'])fig = plt.figure()for i in range(12): plt.subplot(3, 4, i + 1) input=torch.rand(1,20) out=model.decoder(input) out=out.view(28,28) out=transforms.ToPILImage()(out) plt.imshow(out, cmap='gray') plt.xticks([]) plt.yticks([])plt.show()

可以看到随机生成的数据用解码器解码得到的数据都很乱。接下来,看看编码器编码后的数据服从什么分布。

看看编码器编码的输出服从什么分布import numpy as npimport torchvisionimport matplotlib.pyplot as pltfrom torch.utils.data import DataLoaderimport torchimport torchvision.transforms as transformsfrom models import AEfrom data import AddPepperNoise,Gaussian_noisetest_datasets = torchvision.datasets.MNIST('./', train=False, download=True)print('测试集的数量', len(test_datasets))test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)transform=transforms.Compose([ transforms.ToPILImage(), Gaussian_noise(0,0.2), AddPepperNoise(0.9), transforms.ToTensor()])model=AE()hh=torch.load('./model.pth',map_location=lambda storage,loc:storage)model.load_state_dict(hh['model_state_dict'])#错误写法# model=model.load_state_dict(hh['model_state_dict'])fig = plt.figure()for i in range(1): plt.subplot(1, 1, i + 1) img = test_datasets.test_data[i] label = test_datasets.test_labels[i] img_noise=transform(img) img_noise=img_noise.view(1,-1) out=model.encoder(img_noise) print('encoder的输出',out) #正太分布检验 import scipy.stats as stats print(stats.shapiro(out.detach().numpy())) plt.imshow(img, cmap='gray') plt.title(label) plt.xticks([]) plt.yticks([])plt.show()print('均值',torch.mean(out))print('方差',torch.var(out))

 可以看到一张图片7是服从均值为2.5,方差为8.55的正太分布的。

然后生成一些类似的分布看看效果。

import matplotlib.pyplot as pltimport torchimport torchvision.transforms as transformsfrom models import AEmodel=AE()model.load_state_dict(torch.load('./model.pth',map_location=lambda storage,loc:storage)['model_state_dict'])fig = plt.figure()for i in range(12): plt.subplot(3, 4, i + 1) input=torch.normal(mean=2.5928,std=8.5510,size=(1,20)) out=model.decoder(input) out=out.view(28,28) out=transforms.ToPILImage()(out) plt.imshow(out, cmap='gray') plt.xticks([]) plt.yticks([])plt.show()

 其实效果挺差的,可能是因为一张图片的分布并不能代表所有吧。

本文链接地址:https://www.jiuchutong.com/zhishi/298290.html 转载请保留说明!

上一篇:【数据集NO.1】最经典大规模、多样化的自动驾驶视频数据集——BDD100K数据集(数据集介绍)

下一篇:MySQL基本查询(mysqljoin查询)

  • ppt动画效果如何添加声音(ppt动画效果如何做出来)

    ppt动画效果如何添加声音(ppt动画效果如何做出来)

  • a卡和n卡有哪些区别(a卡和n卡有哪些不同)

    a卡和n卡有哪些区别(a卡和n卡有哪些不同)

  • 华为nova7pro电池是多大毫安(华为nova7pro电池多少钱一块)

    华为nova7pro电池是多大毫安(华为nova7pro电池多少钱一块)

  • 红米k30s至尊纪念版的屏幕刷新率是多少(红米k30s至尊纪念版为什么骂声一片)

    红米k30s至尊纪念版的屏幕刷新率是多少(红米k30s至尊纪念版为什么骂声一片)

  • 华为matebook14麦克风没声音是怎么回事(华为电脑mate book 14麦克风在哪)

    华为matebook14麦克风没声音是怎么回事(华为电脑mate book 14麦克风在哪)

  • 宽带错误711是什么意思(宽带显示错误711什么意思)

    宽带错误711是什么意思(宽带显示错误711什么意思)

  • Excel怎么做电子表格(电脑怎么做excel表格)

    Excel怎么做电子表格(电脑怎么做excel表格)

  • 苹果a1530是什么型号(苹果a1530是什么型号多少钱)

    苹果a1530是什么型号(苹果a1530是什么型号多少钱)

  • oppor17手机升级后怎么恢复旧版本(oppor17升级系统)

    oppor17手机升级后怎么恢复旧版本(oppor17升级系统)

  • 苹果手机指纹坏了能修吗(苹果手机指纹坏了是什么原因)

    苹果手机指纹坏了能修吗(苹果手机指纹坏了是什么原因)

  • 红米k30解锁方式(redmi k30s解锁方式)

    红米k30解锁方式(redmi k30s解锁方式)

  • b站电脑端可以缓存吗(B站电脑端可以打卡吗)

    b站电脑端可以缓存吗(B站电脑端可以打卡吗)

  • 拼多多跑单是什么意思(拼多多打单员难吗)

    拼多多跑单是什么意思(拼多多打单员难吗)

  • 联通暂时无法接通是什么意思

    联通暂时无法接通是什么意思

  • qq视频聊天可以录制吗(qq视频聊天可以截图吗)

    qq视频聊天可以录制吗(qq视频聊天可以截图吗)

  • 鼠标中间的按键是干嘛的(鼠标中间的按键是干嘛用的)

    鼠标中间的按键是干嘛的(鼠标中间的按键是干嘛用的)

  • 华为p30抬起亮屏怎么设置(华为p30抬起亮屏失效)

    华为p30抬起亮屏怎么设置(华为p30抬起亮屏失效)

  • 增强短信怎么收费(短信增强信息啥意思收费吗)

    增强短信怎么收费(短信增强信息啥意思收费吗)

  • 黑鲨2支持快充吗(黑鲨2快充支持多少)

    黑鲨2支持快充吗(黑鲨2快充支持多少)

  • 小米手环勿扰模式是什么意思(小米手环勿扰模式怎么设置)

    小米手环勿扰模式是什么意思(小米手环勿扰模式怎么设置)

  • 如何查看小米8屏幕材质(如何查看小米8手机电池损耗程度)

    如何查看小米8屏幕材质(如何查看小米8手机电池损耗程度)

  • 快手怎么下竞猜(快手的竞猜在哪里可以找到)

    快手怎么下竞猜(快手的竞猜在哪里可以找到)

  • 云骑士的系统是正版吗

    云骑士的系统是正版吗

  • 手机qq的安全中心在哪里打开(手机qq安全中心看不到登录足迹)

    手机qq的安全中心在哪里打开(手机qq安全中心看不到登录足迹)

  • cad加文字(cad加文字命令是什么)

    cad加文字(cad加文字命令是什么)

  • 小米手机锁屏后怎么显示时间(小米手机锁屏后右滑功能怎么关闭)

    小米手机锁屏后怎么显示时间(小米手机锁屏后右滑功能怎么关闭)

  • 企业所得税如何做分录
  • 合同金额含税不含税
  • 金融服务费可以谈吗
  • 增值税专票开完就扣税是吗
  • 厨房酒店用品
  • 发生哪些情形的应判定为重大电力安全隐患
  • 出口退税收入做什么科目
  • 农产品增值税进项税额核定扣除办法
  • 废弃土地的使用年限
  • 抵押贷款买房子合适还是商业贷款合适
  • 企业咨询评估
  • 用银行存款上交上月税金会计分录
  • 地税能不能查到个人的贷款行为?
  • 进项税额抵扣的情况有哪些
  • 误餐补助需要发票做账吗
  • 列举20种不征增值税产品
  • 宣传费开票属于什么费用
  • 小规模纳税人年度不超过500万
  • 建筑企业包工包料
  • 完工百分比法确认成本 分录
  • 这个月发票没用怎么做账
  • 企业所得税免税政策
  • 购买厂房可以一次买卖吗
  • 未缴少缴税款追征期
  • 生产成本中的电费计入制造费用吗
  • 初级会计实务的心得体会
  • 税控盘抵减
  • 财务费用利息收入怎么结转
  • 为什么捐赠还要交税
  • cpqdfwag.exe是什么进程 能结束吗 cpqdfwag进程查询
  • cnqmax.exe进程的详细注解 cnqmax进程是病毒吗 安全吗
  • p指针后移的语句
  • 最小型笔记本
  • 增值税发票认证在哪里
  • 关于php文件的自动播放
  • 计提企业所得税是在结转损益后吗
  • 包装费 增值税
  • php简单获取网站的方法
  • XF86Setup命令 设置XFee86
  • iozone测试结果分析
  • 微擎框架是开源的吗
  • 付了两次运费发了一个包裹
  • SQLite学习手册(SQLite在线备份)
  • 个人注册公司是否可以免税
  • 混合销售行为的例子
  • 一次性伤残补助金怎么查询进度
  • 纸质专票红冲
  • 个体户没有税务登记怎么开发票
  • 先收到发票还没付款怎么做账
  • 去年的分红奖金是多少
  • 未开票收入如何计提增值税
  • 工地零电零水布置图
  • 未分配利润是负数怎么消化掉
  • 外借资质交企业所得税怎么交
  • 购买土地的入账价值包括什么
  • 小规模年末怎么做账
  • 工程款预缴税
  • 货款和发票金额不一致
  • 本月未过账的凭证怎么写
  • 电子税务局自然人扣缴客户端
  • 小型润滑油生产设备要哪些
  • 工程发票可以抵扣增值税吗
  • 自然人股权转让如何缴纳个人所得税
  • WINDOWS命令行为什么删除速度很快
  • culauncherexe是什么进程
  • 怎么查显卡信息
  • windows的气泡屏保会加速
  • linux安装迅雷
  • linux ssh -v -p
  • win8.1系统更新
  • win7的命令对话框在哪里
  • chromexcel
  • python里面import
  • 写出perl中最常见4种控制流
  • windows批处理官方教程
  • 创建简单的Web网页实验总结ASP
  • python运行出现none
  • 第一章阎王点卯的小说名字
  • python调用ch
  • 外贸公司如何开发客户
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设