|
- import argparse
- import os
-
- import numpy as np
- import torch
- import torch.nn as nn
- from PIL import Image
- from torch.autograd import Variable
- from torch.utils.data import DataLoader, Dataset
- from torchvision import transforms
- from torchvision.datasets import ImageFolder
- import Model.model as model
- from Model.context_model import Weighted_Gaussian
- from Model.basic_module import anchor
- import time
- from traindata import Traindataset
-
def penalty(input_, a, b):
    """Return the eighth power of ``input_`` rescaled against ``a`` and ``b``.

    ``input_`` is shifted by the midpoint of ``a`` and ``b`` and divided by
    half of ``b - a``, so ``a`` and ``b`` themselves map to -1 and +1 (result
    1) and the midpoint maps to 0. The even power makes the result grow
    steeply once ``input_`` moves outside the two bounds.

    Args:
        input_: Value to penalise; anything supporting ``-``, ``/`` and ``**``.
        a: First bound.
        b: Second bound; must differ from ``a`` to avoid a zero divisor.

    Returns:
        ``((input_ - (b + a) / 2) / ((b - a) / 2)) ** 8``.
    """
    center = (b + a) / 2
    half_width = (b - a) / 2
    return ((input_ - center) / half_width) ** 8
-
def adjust_learning_rate(optimizer, epoch, init_lr, min_lr=1e-6):
    """Apply the step-decay schedule to ``optimizer`` and return the new rate.

    Epochs 0-9 use ``init_lr`` unchanged. From epoch 10 onwards the rate is
    ``init_lr * 0.5 ** ((epoch - 7) // 3)``, i.e. it is halved at epoch 10
    and again every 3 epochs after that. The result is never allowed to
    drop below ``min_lr``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts;
            the ``'lr'`` entry of every group is overwritten.
        epoch: Zero-based epoch index.
        init_lr: Learning rate before any decay.
        min_lr: Lower bound for the returned rate (defaults to the
            previously hard-coded 1e-6).

    Returns:
        The learning rate that was written into every parameter group.
    """
    if epoch < 10:
        lr = init_lr
    else:
        lr = init_lr * (0.5 ** ((epoch - 7) // 3))

    # Clamp so long runs do not decay the rate to effectively zero.
    lr = max(lr, min_lr)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
-
def train(args):
    """Train the image-coding network on CUDA with an MSE-plus-rate loss.

    Builds ``model.Image_coding`` when ``args.fc == 0`` and
    ``model.Image_coding_lfc`` otherwise, loads the ``mse64`` / ``mse64lfc``
    checkpoint from ``args.out_dir`` if it exists, wraps the network in
    ``nn.DataParallel`` and runs 10 epochs of Adam over ``Traindataset``.

    Every 100 steps the running averages (loss, bpp, PSNR) are printed and
    appended to a text log in ``args.out_dir``; every 1000 steps the
    underlying module's ``state_dict`` is saved there as well.

    Args:
        args: Parsed command-line namespace. Reads ``b_size``, ``fc``,
            ``out_dir``, ``M``, ``N2``, ``gpu``, ``lr`` and ``lmd``.
    """
    root = '../../../userdata'

    train_data = Traindataset(root)
    print(len(train_data))
    train_loader = DataLoader(train_data, batch_size=args.b_size,
                              shuffle=True, num_workers=8)

    # All file names depend only on the model variant and lambda, so resolve
    # them once instead of rebuilding them inside the training loop.
    use_base = args.fc == 0
    lmd_tag = 'mse' + str(int(args.lmd))
    if use_base:
        # NOTE(review): the resume file is always the lambda-64 checkpoint,
        # while saving uses int(args.lmd) — presumably fine-tuning from the
        # 64 model; confirm this is intended.
        resume_path = os.path.join(args.out_dir, 'mse64.pkl')
        log_path = os.path.join(args.out_dir, lmd_tag + '.txt')
        save_path = os.path.join(args.out_dir, lmd_tag + '.pkl')
    else:
        resume_path = os.path.join(args.out_dir, 'mse64lfc.pkl')
        log_path = os.path.join(args.out_dir, lmd_tag + 'lfc_fine.txt')
        save_path = os.path.join(args.out_dir, lmd_tag + 'lfc.pkl')

    # The log and checkpoints are written into out_dir; create it up front so
    # the first log append does not fail with FileNotFoundError.
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    model_existed = os.path.exists(resume_path)
    net_cls = model.Image_coding if use_base else model.Image_coding_lfc
    whole = net_cls(3, args.M, args.N2, args.M, args.M // 2).cuda()

    if model_existed:
        whole.load_state_dict(torch.load(resume_path))
        print('previous model resumed')
    else:
        print('no previous model')

    gpu_id = list(range(args.gpu))
    whole = nn.DataParallel(whole, device_ids=gpu_id)

    loss_func = nn.MSELoss()

    opt = torch.optim.Adam(whole.parameters(), lr=args.lr)

    # NOTE(review): with only 10 epochs (0-9) the decay branch of
    # adjust_learning_rate (epoch >= 10) is never reached from here.
    for epoch in range(10):
        rec_loss_tmp = 0
        last_time = time.time()
        train_bpp_tmp = 0
        mse_tmp = 0
        cur_lr = adjust_learning_rate(opt, epoch, args.lr)

        for step, batch_x in enumerate(train_loader):
            # Batch size times dims 2 and 3 — assumes an (N, C, H, W) batch;
            # TODO confirm against Traindataset.
            num_pixels = batch_x.size()[0] * \
                batch_x.size()[2] * batch_x.size()[3]
            batch_x = batch_x.cuda()

            # The meaning of the second positional argument (1) is defined by
            # the model's forward(); not visible from this file.
            fake, xp2, xp3 = whole(batch_x, 1)

            # Rate estimate: -sum(log2(p)) / num_pixels over both outputs.
            train_bpp_total = torch.sum(torch.log(xp2)) / (-np.log(2) * num_pixels) + torch.sum(torch.log(xp3)) / (
                -np.log(2) * num_pixels)
            dloss = loss_func(fake, batch_x)

            # Disabled experiment, kept for reference:
            #   diff = args.br * dloss - anchor.apply(train_bpp_total)
            # Experiments showed br = 1, 2, 3, 4, 5 behave about the same
            # (2021-07-27).

            l_rec = args.lmd * dloss + 0.01 * train_bpp_total

            opt.zero_grad()
            l_rec.backward()
            torch.nn.utils.clip_grad_norm_(whole.parameters(), 5)
            opt.step()

            rec_loss_tmp += l_rec.item()
            train_bpp_tmp += train_bpp_total.item()
            mse_tmp += torch.mean(dloss).item()

            if step % 100 == 0:
                time_used = time.time() - last_time
                last_time = time.time()
                mse = mse_tmp / (1 + step)
                # Peak value 1.0 — assumes inputs are scaled to [0, 1];
                # TODO confirm.
                psnr = 10 * np.log10(1.0 / mse)
                bpp_total = train_bpp_tmp / (1 + step)
                log_line = (
                    'ep:%d step:%d time:%.1f lr:%.8f loss:%.6f bpp:%.4f psnr_ave:%.2f\n'
                    % (epoch, step, time_used, cur_lr, rec_loss_tmp / (step + 1), bpp_total, psnr))
                print(log_line)

                # The context manager closes the file; no explicit close().
                with open(log_path, 'a') as fd:
                    fd.write(log_line)

            if (step + 1) % 1000 == 0:
                # Save the wrapped module so the checkpoint keys carry no
                # DataParallel 'module.' prefix and match load_state_dict above.
                torch.save(whole.module.state_dict(), save_path)
-
if __name__ == '__main__':
    # Script entry point: register the command-line options in a fixed
    # order, parse them, echo the resulting namespace and start training.
    parser = argparse.ArgumentParser()
    for flag, options in (
        ("--M", dict(type=int, default=256, help="the value of M")),
        ("--b_size", dict(type=int, default=6)),
        ("--N2", dict(type=int, default=192, help="the value of N2")),
        ("--lr", dict(type=float, default=5e-5, help="initial learning rate.")),
        ('--out_dir', dict(type=str, default='output')),
        ("--gpu", dict(type=int, default=1)),
        # 0 selects model.Image_coding in train(); any other value selects
        # model.Image_coding_lfc.
        ("--fc", dict(type=int, default=2)),
        # Weight of the MSE term in the training loss.
        ("--lmd", dict(type=float, default=64)),
        # Only referenced by commented-out code in train().
        ("--br", dict(type=float, default=64)),
    ):
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    print(args)

    train(args)
|