|
- from genericpath import exists
- import os
- import torch
- import logging
- import random
- import numpy as np
- import warnings
- from args import get_parser
- from dataset.data import data_prefetcher, AllData
- import torch.distributed as dist
- from models.sfcn import SFCN
- from models.vgg import vgg16_bn
- from models.dbn import DBN
- from models.resnet import resnet18, resnet34, resnet50
- from models.densenet import densenet121, densenet201
- from apex import amp
- from apex.parallel import DistributedDataParallel
- import torch.backends.cudnn as cudnn
- from torch.utils.data import DataLoader
- from torch.autograd import Variable
- from utils.utils import reduce_mean,adjust_learning_rate, AverageMeter, ProgressMeter, my_KLDivLoss, weight_MSE, weight_kdloss
-
- import torch.nn.functional as F
-
def initialize():
    """Parse CLI args, seed every RNG source for reproducibility, and set up logging.

    Returns:
        tuple: ``(args, logger)`` — parsed arguments from ``get_parser()`` and a
        logger that writes to both ``logs/<env_name>.txt`` and the console.
    """
    # get args
    args = get_parser()

    # silence noisy library warnings
    warnings.filterwarnings("ignore")

    logger = logging.getLogger(__name__)

    # Seed all RNG sources (Python, hash, NumPy, Torch CPU/GPU) so runs repeat.
    seed = 1111
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels; benchmark=False disables nondeterministic autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True

    # initialize logger
    logger.setLevel(level=logging.INFO)

    # exist_ok avoids the check-then-create race under multiple processes
    os.makedirs("logs", exist_ok=True)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    handler = logging.FileHandler("logs/%s.txt" % args.env_name)
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    # fix: console handler previously had no formatter, unlike the file handler
    console.setFormatter(formatter)

    logger.addHandler(handler)
    logger.addHandler(console)
    return args, logger
-
def main():
    """Entry point: build the (config, logger) pair, record the GPU count, and run the worker."""
    cfg, log = initialize()
    cfg.nprocs = torch.cuda.device_count()
    main_worker(cfg, log)
-
def main_worker(config, logger):
    """Build student/teacher models, wrap the student for distributed AMP
    training, and run the train/validate loop, checkpointing every new best MAE.

    Args:
        config: parsed arguments (arch, mode, lr, epochs, batch_size, nprocs,
            local_rank, env_name, ...). ``batch_size`` is rebinned per process.
        logger: configured ``logging.Logger`` shared with train()/validate().
    """
    model_names = ["resnet18", "resnet50", "vgg", "dense121", "sfcn", "dbn"]
    models = [resnet18, resnet50, vgg16_bn, densenet121, SFCN, DBN]

    best_acc1 = 99.0  # best (lowest) validation MAE so far; 99.0 = "not yet set"

    dist.init_process_group(backend='nccl')

    # create the student model and an identically-configured teacher
    arch_idx = model_names.index(config.arch)
    model = models[arch_idx](output_dim=88, mode=config.mode)
    print(arch_idx, config.arch, config.mode)

    T_model = models[arch_idx](output_dim=88, mode=config.mode)
    # Pretrained teacher weights exist only for the baseline modes 0-2; the
    # three original branches differed only in the mode number, so they are
    # collapsed here. Other modes leave the teacher randomly initialized
    # (original behavior preserved).
    if config.mode in (0, 1, 2):
        checkpoint = torch.load('./baseline/model-%s-_mode-%d' % (config.arch, config.mode))
        T_model.load_state_dict(checkpoint['state_dict'])

    torch.cuda.set_device(config.local_rank)
    model.cuda()
    T_model.cuda()

    # split the global batch size evenly across the distributed processes
    config.batch_size = int(config.batch_size / config.nprocs)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=0.00005)

    model, optimizer = amp.initialize(model, optimizer, opt_level=config.opt_level)
    model = DistributedDataParallel(model)

    cudnn.benchmark = True

    # Data loading code
    train_data = AllData(config.data, train=True)
    val_data = AllData(config.data, train=False)

    # Per-sample prediction buffers; shape depends on whether the mode produces
    # scalar outputs or 22-bin distributions (passed through to train()).
    if config.mode in [3, 4, 6, 9]:
        all_predictions = torch.zeros((len(train_data)))
        all_predictions_kl = torch.zeros((len(train_data)))
    elif config.mode == 7:
        all_predictions = torch.zeros((len(train_data)))
        all_predictions_kl = torch.zeros((len(train_data), 22))
    else:
        all_predictions = torch.zeros((len(train_data), 22))
        all_predictions_kl = torch.zeros((len(train_data), 22))

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)

    # shuffle=False because the DistributedSampler already shuffles per epoch
    train_loader = DataLoader(train_data, config.batch_size, shuffle=False,
                              num_workers=8, pin_memory=True, sampler=train_sampler)
    val_loader = DataLoader(val_data, config.batch_size, shuffle=False,
                            num_workers=4, pin_memory=True, sampler=val_sampler)

    for epoch in range(config.epochs):
        train_sampler.set_epoch(epoch)

        adjust_learning_rate(optimizer, epoch, config)

        # distillation weight ramps linearly up to 0.8 over training
        alpha_t = 0.8 * ((epoch + 1) / config.epochs)
        alpha_t = max(0, alpha_t)

        # train for one epoch
        train(train_loader, model, T_model, optimizer, epoch, config,
              all_predictions, all_predictions_kl, alpha_t, logger)

        mae = validate(val_loader, model, config, logger)

        is_best = mae < best_acc1
        best_acc1 = min(mae, best_acc1)

        # exist_ok replaces the old try/except-pass workaround for the
        # multi-process directory-creation race
        os.makedirs("./checkpoints/%s" % config.env_name, exist_ok=True)

        # only rank 0 writes checkpoints, and only on improvement
        # (the dead nested `if 1 > 0:` from the original is removed)
        if is_best and config.local_rank == 0:
            state = {
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict(),
                'best_acc1': best_acc1,
                'amp': amp.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(state, './checkpoints/%s/%s_epoch_%s_%s' % (config.env_name, config.env_name, epoch, best_acc1))
-
def train(train_loader, model, T_model , optimizer, epoch, config, all_predictions,all_predictions_kl, alpha_t, logger):
    """Run one knowledge-distillation training epoch.

    The student (`model`) is trained against ground truth and the frozen
    teacher (`T_model`); which losses are combined depends on `config.mode`:
    modes 0-2 are plain KD baselines (MSE / DEX / soft-label), modes 3-5 add
    teacher-error-based sample weights and a representation-matching term.

    NOTE(review): `all_predictions`, `all_predictions_kl`, and `alpha_t` are
    accepted but never referenced in this function body — presumably used by
    other modes elsewhere; confirm before removing.

    Args:
        train_loader: distributed training DataLoader (wrapped by data_prefetcher).
        model: student network in train mode.
        T_model: teacher network, set to eval mode here.
        optimizer: Adam optimizer (AMP-initialized by the caller).
        epoch: current epoch index, used only for progress display.
        config: provides mode, T (temperature), alpha, beta, scale, print_freq, nprocs.
        alpha_t: epoch-dependent distillation weight (unused here — see NOTE above).
        logger: logger for the epoch-end summary line.
    """
    losses = AverageMeter('Loss', ':.4e')
    loss_mae = AverageMeter('mae1', ':6.2f')

    progress = ProgressMeter(len(train_loader), [losses, loss_mae],
        prefix="Epoch: [{}]".format(epoch), logger = logger)

    model.train()
    T_model.eval()

    prefetcher = data_prefetcher(train_loader)
    # each batch: images, age target, soft-label distribution yy, bin centers bc, sample indices
    images, target, yy, bc, indices = prefetcher.next()
    i = 0
    # NOTE(review): zero_grad + step before any backward presumably warms up
    # optimizer/AMP state before the first real update — confirm intent.
    optimizer.zero_grad()
    optimizer.step()

    T = config.T            # softmax temperature for distillation
    alpha = config.alpha    # weight of the KD loss term
    beta = config.beta      # weight of the representation-matching term
    scale = config.scale    # normalizer for teacher-error-based sample weights
    while images is not None:

        # forward both networks; each returns (output, pre-softmax logits, representation)
        out, out_p, rep = model(images, use_mine = True)
        T_out, T_out_p, T_rep = T_model(images, use_mine = True)

        # Reporting metric: modes divisible by 3 regress age directly;
        # others predict a distribution whose expectation over bin centers
        # `bc` gives the age estimate.
        if config.mode % 3 == 0:
            mae = torch.nn.L1Loss()(out, target)
        else:
            prob = torch.exp(out)
            pred = torch.sum(prob * bc, dim = 1)
            mae = torch.nn.L1Loss()(pred, target)
            T_pred = torch.sum(torch.exp(T_out) * bc, dim = 1)

        ################################################################################################################################
        #------------------------------------------------- baseline ---------------------------------
        if config.mode == 0:
            # baseline mse: supervised MSE blended with MSE to the teacher output
            loss = torch.nn.MSELoss()(out, target)
            losskd = torch.nn.MSELoss()(out, T_out)
            loss = loss*(1-alpha) + losskd*alpha

        elif config.mode == 1:
            # baseline dex: expectation regression + temperature-scaled KL distillation
            loss = torch.nn.MSELoss()(pred, target)
            kdloss = torch.nn.KLDivLoss()
            ###
            scorea = F.log_softmax(out_p/T,dim=1)
            scoreb = F.softmax(T_out_p/T,dim=1)
            losskd = kdloss(scorea,scoreb)*T*T

            loss = loss*(1-alpha) + losskd*alpha

        elif config.mode == 2:
            # baseline soft label: KL to the label distribution yy + KD term
            loss = my_KLDivLoss(out, yy)
            kdloss = torch.nn.KLDivLoss()

            scorea = F.log_softmax(out_p/T,dim=1)
            scoreb = F.softmax(T_out_p/T,dim=1)
            losskd = kdloss(scorea,scoreb)*T*T

            loss = loss*(1-alpha) + losskd*alpha

        elif config.mode == 3:
            # baseline mse + sample weights: samples where the teacher errs badly
            # (|T_out - target| >= scale) get weight 0; accurate ones approach 1
            weight = abs(T_out - target) / scale
            weight[weight >1] = 1
            weight = 1 - weight

            loss = torch.nn.MSELoss()(out, target)
            losskd = weight_MSE(out, T_out, weight)
            lossrep = weight_MSE(rep, T_rep, weight)
            loss = loss + losskd*alpha + beta *lossrep

        elif config.mode == 4:
            # baseline dex + sample weights (same weighting scheme as mode 3)
            weight = abs(T_out - target) / scale
            weight[weight >1] = 1
            weight = 1 - weight

            loss = torch.nn.MSELoss()(pred, target)
            kdloss = torch.nn.KLDivLoss()
            ###
            scorea = F.log_softmax(out_p/T,dim=1)
            scoreb = F.softmax(T_out_p/T,dim=1)
            losskd = weight_kdloss(scorea,scoreb,weight)*T*T
            # NOTE(review): rep loss is scaled by T*T here and in mode 5 but
            # not in mode 3 — confirm this asymmetry is intended.
            lossrep = weight_MSE(rep, T_rep, weight)*T*T

            loss = loss + losskd * alpha + lossrep * beta

        elif config.mode == 5:
            # baseline soft label + sample weights; weights use the teacher's
            # expectation-based prediction T_pred rather than raw T_out
            loss = my_KLDivLoss(out, yy)
            weight = abs(T_pred - target) / scale
            weight[weight >1] = 1
            weight = 1 - weight

            kdloss = torch.nn.KLDivLoss()
            ###
            scorea = F.log_softmax(out_p/T,dim=1)
            scoreb = F.softmax(T_out_p/T,dim=1)

            losskd = weight_kdloss(scorea,scoreb, weight)*T*T

            lossrep = weight_MSE(rep, T_rep, weight)*T*T

            loss = loss + losskd*alpha + beta *lossrep
        else:
            print("Not Support.")
            exit()

        # sync all ranks before averaging the metrics across them
        torch.distributed.barrier()

        reduced_loss = reduce_mean(loss, config.nprocs)
        reduced_mae = reduce_mean(mae, config.nprocs)

        losses.update(reduced_loss.item(), images.size(0))
        loss_mae.update(reduced_mae.item(), images.size(0))

        optimizer.zero_grad()
        # apex AMP loss scaling: backward must run on the scaled loss
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        if i % config.print_freq == 0:
            progress.display(i)

        i += 1

        images, target, yy, bc,indices = prefetcher.next()

    logger.info("[train mae]: %.4f" % float(loss_mae.avg))
-
-
def validate(val_loader, model, config, logger):
    """Evaluate on the validation set and return the distributed-averaged MAE.

    Mirrors the training metric: modes divisible by 3 regress the target
    directly; all other modes take the expectation of the predicted
    distribution over the bin centers ``bc`` as the prediction.
    """
    mae_meter = AverageMeter('mae1', ':6.2f')
    progress = ProgressMeter(len(val_loader), [mae_meter], prefix='Test: ', logger = logger)
    model.eval()

    with torch.no_grad():
        fetcher = data_prefetcher(val_loader)
        images, target, yy, bc, indices = fetcher.next()
        step = 0
        while images is not None:

            out = model(images)

            # expectation-based prediction unless the mode is a direct regressor
            if config.mode % 3 != 0:
                expected = torch.sum(torch.exp(out) * bc, dim = 1)
                mae = torch.nn.L1Loss()(expected, target)
            else:
                mae = torch.nn.L1Loss()(out, target)

            # sync ranks, then average the metric across all workers
            torch.distributed.barrier()
            mae_meter.update(reduce_mean(mae, config.nprocs).item(), images.size(0))

            if step % config.print_freq == 0:
                progress.display(step)
            step += 1

            images, target, yy, bc, indices = fetcher.next()

    logger.info("[val mae]: %.4f" % float(mae_meter.avg))
    return mae_meter.avg
-
-
# Script entry point: launch the full training run.
if __name__ == '__main__':
    main()
|