|
- import argparse
- import os
-
- import numpy as np
- import torch
- import torch.nn as nn
- from PIL import Image
- from torch.autograd import Variable
- from torch.utils.data import DataLoader, Dataset
- from torchvision import transforms
- from torchvision.datasets import ImageFolder
- import Model.model as model
- from Model.context_model import Weighted_Gaussian
- import time
- from traindata import Traindataset
- import warnings
-
- warnings.filterwarnings("ignore")
-
-
def adjust_learning_rate(optimizer, epoch, init_lr):
    """Step-decay learning-rate schedule.

    Keeps ``init_lr`` for the first 10 epochs, then halves it every
    3 epochs (offset so the first halving lands at epoch 10), clamped
    to a floor of 1e-6.  The resulting rate is written into every
    parameter group of ``optimizer`` and returned.
    """
    if epoch < 10:
        rate = init_lr
    else:
        rate = max(init_lr * 0.5 ** ((epoch - 7) // 3), 1e-6)
    for group in optimizer.param_groups:
        group['lr'] = rate
    return rate
-
-
def train(args):
    """Train the ViT-based image-compression model with a rate-distortion loss.

    Loads ``Traindataset`` from a hard-coded root, optionally resumes the
    coding network and context model from ``args.out_dir``, then runs 20
    epochs of Adam optimization.  Progress is printed and appended to
    ``VITmse.txt``; both networks are checkpointed periodically.

    Args:
        args: parsed CLI namespace (M, N2, b_size, lr, lmd, depth, gpu, out_dir).

    Requires CUDA; batches that produce a NaN loss are skipped entirely.
    """
    root = '../../../userdata'  # NOTE(review): hard-coded data root — confirm deployment path

    t0 = time.time()
    train_data = Traindataset(root)
    print(len(train_data))
    train_loader = DataLoader(train_data, batch_size=args.b_size,
                              shuffle=True, num_workers=8)
    print(time.time() - t0)

    img_comp = model.Image_coding_multi_hyper_vit(
        3, args.M, args.N2, args.M, args.M // 2, image_size=256,
        patch_size=32, depth=args.depth, dim_head=64, heads=16).cuda()
    context = Weighted_Gaussian(args.M).cuda()

    # Resume only when BOTH checkpoints are present.
    base_model_existed = (
        os.path.exists(os.path.join(args.out_dir, 'nic_vit1.pkl'))
        and os.path.exists(os.path.join(args.out_dir, 'context_vit1.pkl')))

    # Wrap in DataParallel BEFORE loading so the 'module.'-prefixed
    # checkpoint keys match the wrapped state dicts.
    device_ids = list(range(args.gpu))  # was a comprehension shadowing builtin `id`
    img_comp = nn.DataParallel(img_comp, device_ids=device_ids)
    context = nn.DataParallel(context, device_ids=device_ids)

    if base_model_existed:
        img_comp.load_state_dict(torch.load(os.path.join(args.out_dir, 'nic_vit1.pkl')))
        context.load_state_dict(torch.load(os.path.join(args.out_dir, 'context_vit1.pkl')))
        print('resumed the base model')
    else:
        print('base model not found')

    opt1 = torch.optim.Adam(img_comp.parameters(), lr=args.lr)
    opt2 = torch.optim.Adam(context.parameters(), lr=args.lr)
    # NOTE(review): anomaly detection slows training noticeably; disable once stable.
    torch.autograd.set_detect_anomaly(True)
    loss_func = nn.MSELoss()

    for epoch in range(20):
        # Running sums over the *applied* (non-NaN) steps of this epoch.
        rec_loss_tmp = 0
        last_time = time.time()
        train_bpp1_tmp = 0
        train_bpp2_tmp = 0
        train_bpp3_tmp = 0
        mse_tmp = 0
        # NOTE(review): the schedule is applied to opt1 only; opt2 stays at
        # args.lr for all 20 epochs — confirm this is intended.
        adjust_learning_rate(opt1, epoch, args.lr)

        real_step = -1  # counts only steps whose gradients were actually applied
        for step, batch_x in enumerate(train_loader):
            num_pixels = batch_x.size()[0] * batch_x.size()[2] * batch_x.size()[3]
            batch_x = batch_x.cuda()  # Variable() wrapper is a deprecated no-op since torch 0.4

            fake, xp1, xp2, xq1, hyper_dec, xp3, xq3 = img_comp(batch_x, 1)
            # Replace the coder-side likelihood with the context-model estimate.
            xp1, _ = context(xq1, hyper_dec)

            # Bits-per-pixel of each latent stream: -log2(likelihood) / pixels.
            train_bpp1 = torch.sum(torch.log(xp1)) / (-np.log(2) * num_pixels)
            train_bpp2 = torch.sum(torch.log(xp2)) / (-np.log(2) * num_pixels)
            train_bpp3 = torch.sum(torch.log(xp3)) / (-np.log(2) * num_pixels)
            train_bpp_total = train_bpp1 + train_bpp2 + train_bpp3

            dloss = loss_func(fake, batch_x)
            # Rate-distortion objective: lambda * MSE + 0.01 * total bpp.
            l_rec = args.lmd * dloss + 0.01 * train_bpp_total

            opt1.zero_grad()
            opt2.zero_grad()

            # `if torch.isnan(l_rec).sum() > 0:` on the tensor also works, but
            # syncing to the host once via .item() keeps the check cheap.
            loss_is_nan = (torch.isnan(l_rec).sum() > 0).item()

            if not loss_is_nan:
                l_rec.backward()
                real_step += 1

                torch.nn.utils.clip_grad_norm_(img_comp.parameters(), 5)
                torch.nn.utils.clip_grad_norm_(context.parameters(), 5)

                opt1.step()
                opt2.step()

                rec_loss_tmp += l_rec.item()
                mse_tmp += dloss.item()
                train_bpp1_tmp += train_bpp1.item()
                train_bpp2_tmp += train_bpp2.item()
                train_bpp3_tmp += train_bpp3.item()

                if real_step % 100 == 0:
                    time_used = time.time() - last_time
                    last_time = time.time()
                    # Averages over the applied steps so far (real_step is 0-based).
                    mse = mse_tmp / (1 + real_step)
                    psnr = 10 * np.log10(1.0 / mse)
                    bpp_total = (train_bpp1_tmp + train_bpp2_tmp
                                 + train_bpp3_tmp) / (1 + real_step)

                    log_line = (
                        'ep:%d step:%d real_step:%d, time:%.1f loss:%.6f bpp:%.4f psnr_ave:%.2f\n'
                        % (epoch, step, real_step, time_used,
                           rec_loss_tmp / (1 + real_step), bpp_total, psnr))
                    print(log_line)
                    with open(os.path.join(args.out_dir, 'VITmse.txt'), 'a') as fd:
                        fd.write(log_line)  # `with` closes fd; the explicit close() was redundant

                if (real_step + 2) % 300 == 0:
                    print('model written.')
                    torch.save(img_comp.state_dict(),
                               os.path.join(args.out_dir, 'nic_vit1.pkl'))
                    torch.save(context.state_dict(),
                               os.path.join(args.out_dir, 'context_vit1.pkl'))
-
-
if __name__ == '__main__':
    # CLI entry point: parse hyper-parameters, echo them, then launch training.
    cli = argparse.ArgumentParser()
    cli.add_argument("--M", type=int, default=192, help="the value of M")
    cli.add_argument("--b_size", type=int, default=6)
    cli.add_argument("--N2", type=int, default=128, help="the value of N2")
    cli.add_argument("--lr", type=float, default=5e-5, help="initial learning rate.")
    cli.add_argument('--out_dir', type=str, default='./VIT/')
    cli.add_argument('--weights_dir', type=str, default='/userdata/Weights.zip/Weights/')
    cli.add_argument("--gpu", type=int, default=1)
    cli.add_argument("--lmd", type=float, default=4)
    cli.add_argument("--depth", type=int, default=10)

    config = cli.parse_args()
    print(config)

    train(config)
|