|
- """
- 添加了高斯噪声的自动编码器(DAE),用含有噪声的图片进行训练,和原图片对比求损失
-
- """
- import torch
- import torch.nn as nn
- from torch.utils.data import DataLoader
- from torchvision.utils import save_image
- from torchvision import datasets, transforms, models
- from UnionNet import Union_net
- import os
- # import matplotlib.pyplot as plt
- import numpy as np
-
- # from PIL import Image
-
# Number of full passes over the training loader in the __main__ loop.
num_epoch: int = 30
-
-
def add_noise(inputs, noise_factor=0.5):
    """Corrupt ``inputs`` with additive Gaussian noise and clip to [0, 1].

    Args:
        inputs: floating-point tensor of pixel values (presumably already in
            [0, 1] — confirm against callers).
        noise_factor: multiplier applied to standard-normal noise drawn with
            ``torch.randn_like`` (torch's RNG).

    Returns:
        A new tensor shaped like ``inputs`` with every element clamped into
        the closed interval [0, 1].
    """
    gaussian = torch.randn_like(inputs)
    return torch.clamp(inputs + noise_factor * gaussian, min=0., max=1.)
-
-
def add_noise_(inputs, noise_factor=0.75, mean=0.0, var=1.0):
    """Add NumPy-sampled Gaussian noise to ``inputs`` and clip to [0, 1].

    Noise is drawn from NumPy's global RNG (``np.random.normal``), so it is
    reproducible via ``np.random.seed`` rather than ``torch.manual_seed``.

    Args:
        inputs: tensor (or any array-like accepted by ``torch.as_tensor``) of
            pixel values; assumed to already lie in [0, 1] — TODO confirm.
        noise_factor: multiplier applied to the sampled noise.
        mean: mean of the Gaussian before scaling by ``noise_factor``.
        var: variance of the Gaussian before scaling by ``noise_factor``.

    Returns:
        A float32 tensor shaped like ``inputs``, on the same device as
        ``inputs``, with every element clamped into [0, 1].

    Raises:
        ValueError: if ``var`` is negative.
    """
    if var < 0:
        raise ValueError('var must be non-negative, got {}'.format(var))
    inputs = torch.as_tensor(inputs)
    # np.random.normal's ``scale`` is a standard deviation, not a variance.
    std = float(np.sqrt(var))
    noise = np.random.normal(mean, std, size=tuple(inputs.shape)) * noise_factor
    # Convert explicitly instead of relying on implicit Tensor + ndarray mixing,
    # and keep the noise on the same device as the inputs.
    noise = torch.as_tensor(np.asarray(noise), device=inputs.device)
    # Sum and clip in the noise's float64 precision, then cast once;
    # .to() avoids the torch.tensor(<Tensor>) copy-construct warning.
    noisy = torch.clip(inputs + noise, 0., 1.)
    return noisy.to(torch.float32)
-
-
if __name__ == '__main__':
    # ToTensor() turns each MNIST PIL image into a float tensor (documented by
    # torchvision to scale pixels to [0, 1], the range add_noise_ clips to).
    trans = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_mnist_data = datasets.MNIST(root='/tmp/dataset', transform=trans, train=True, download=False)
    test_mnist_data = datasets.MNIST(root='/tmp/dataset', transform=trans, train=False)
    # Loaders yield (images, labels) in batches of 128 (the last may be smaller).
    train_loader = DataLoader(train_mnist_data, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_mnist_data, batch_size=128, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs('./img', exist_ok=True)
    os.makedirs('./params', exist_ok=True)
    # Checkpoints go here; create it up front so torch.save cannot fail on a
    # missing directory.
    checkpoint_dir = '/tmp/output'
    os.makedirs(checkpoint_dir, exist_ok=True)

    net = Union_net().to(device)
    net.train()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters())

    print("==================start train==================")
    for epoch in range(num_epoch):
        for i, (clean, _label) in enumerate(train_loader):
            # Corrupt on the CPU, then move both views to the training device.
            # .to(device) (not .cuda(device)) so CPU-only machines also work.
            noisy = add_noise_(clean, mean=0, var=1, noise_factor=0.3).to(device)
            clean = clean.to(device)

            # Denoising objective: reconstruct the clean image from the noisy one.
            output = net(noisy)
            loss = criterion(output, clean)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 128 == 0:
                print('Epoch[{}/{}], index:{} ,loss:{:.3f}'.format(epoch + 1, num_epoch, i, loss.item()))

        # Save once per epoch, after every batch has been applied, rather than
        # overwriting the same file several times mid-epoch.
        save_path = os.path.join(checkpoint_dir, 'net_' + str(epoch) + '.pth')
        torch.save(net.state_dict(), save_path)

    print("=============train end=================")
|