|
- from sys import prefix
- import os
- import torch
- import logging
- import random
- import numpy as np
- import warnings
- import time
- from args import get_parser
- from dataset.data_lmdb import data_prefetcher, AllData_DataFrame
- import torch.distributed as dist
- from models.sfcn_miniV2 import SFCN
- from models.vgg import vgg16_bn
- from models.dbnV2 import DBN
- from models.resnetV2 import resnet18, resnet34, resnet50
- from models.densenetV2 import densenet121, densenet201
- from apex import amp
- import torch.nn.functional as F
- from apex.parallel import DistributedDataParallel
- import torch.backends.cudnn as cudnn
- from torch.utils.data import DataLoader
- from torch.autograd import Variable
- from utils.utils import reduce_mean,adjust_learning_rate, AverageMeter, ProgressMeter, my_KLDivLoss
- root = "/".join(__file__.split("/")[:-1])
- def initialize():
- # get args
- args = get_parser()
-
- # warnings
- warnings.filterwarnings("ignore")
-
- # logger
- logger = logging.getLogger(__name__)
-
- # set seed
- seed = int(1111)
- random.seed(seed)
- os.environ['PYTHONHASHSEED'] = str(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- torch.cuda.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
- torch.backends.cudnn.enabled = True
-
- # initialize logger
- logger.setLevel(level = logging.INFO)
-
- if not os.path.exists(f"{root}/logs"):
- os.makedirs(f"{root}/logs")
-
- handler = logging.FileHandler(f"{root}/logs/{args.env_name}.txt")
- handler.setLevel(logging.INFO)
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
- handler.setFormatter(formatter)
-
- console = logging.StreamHandler()
- console.setLevel(logging.INFO)
-
- logger.addHandler(handler)
- logger.addHandler(console)
- return args, logger
-
- def main():
- config, logger = initialize()
- config.nprocs = torch.cuda.device_count()
- main_worker(config, logger)
-
- def main_worker(config, logger):
- model_names = ["resnet18", "resnet50", "vgg", "dense121", "sfcn", "dbn"]
- models = [resnet18, resnet50,vgg16_bn, densenet121, SFCN, DBN]
-
- best_mae = 99.0
- dist.init_process_group(backend='nccl')
- # create model
- model_t = models[model_names.index(config.arch)](output_dim=88, mode = config.mode)
- model_s = models[model_names.index(config.arch)](output_dim=88, mode = config.mode)
-
- torch.cuda.set_device(config.local_rank)
-
- # find the best teacher epoch
- dirs = f"{root}/checkpoints/{config.arch}_teacher_fold-{config.fold}-model-{config.arch}"
- files = os.listdir(dirs)
- trained_epoch = [int(f.split("_")[-2]) for f in files]
- max_epoch = max(trained_epoch)
- use_file = files[trained_epoch.index(max_epoch)]
- logger.info(f"Loaded file: {use_file}")
- checkpoint = torch.load(os.path.join(dirs,use_file), map_location='cpu')
- model_t.load_state_dict(checkpoint['state_dict'])
-
- model_t.cuda().eval()
- model_s.cuda()
-
- config.batch_size = int(config.batch_size / config.nprocs)
-
- optimizer = torch.optim.Adam(model_s.parameters(),lr = config.lr,weight_decay = 0.0001)
- #optimizer.load_state_dict(checkpoint['optimizer'])
-
- model_s, optimizer = amp.initialize(model_s, optimizer, opt_level=config.opt_level)
- #amp.load_state_dict(checkpoint['amp'])
- model_s = DistributedDataParallel(model_s)
-
- cudnn.benchmark = True
-
- # Data loading code
- use_lmdb = True
- # lmdb:
- if use_lmdb is True:
- from dataset.data_lmdb import AllData_DataFrame
- train_data = AllData_DataFrame(f"{root}/dataset/brain_age_dataset.csv",f"{root}/dataset/brain_age/data.lmdb",config, train = True,generate_age_dist=True)
- val_data = AllData_DataFrame(f"{root}/dataset/brain_age_dataset.csv",f"{root}/dataset/brain_age/data.lmdb",config, train = False,generate_age_dist=True)
- else:
- from dataset.data import AllData_DataFrame
- train_data = AllData_DataFrame(f"{root}/dataset/brain_age_dataset.csv",config,replace_path="/data3/yangyw/Datas", train = True,generate_age_dist=True)
- val_data = AllData_DataFrame(f"{root}/dataset/brain_age_dataset.csv",config,replace_path="/data3/yangyw/Datas", train = False,generate_age_dist=True)
-
-
- train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
- val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
- #test_sampler = torch.utils.data.distributed.DistributedSampler(test_data)
-
- train_loader = DataLoader(train_data,config.batch_size,
- shuffle=False,num_workers=8,pin_memory = True, sampler = train_sampler)
- val_loader = DataLoader(val_data,config.batch_size,
- shuffle=False,num_workers=4,pin_memory = False, sampler = val_sampler)
- #test_loader = DataLoader(test_data,config.batch_size,
- # shuffle=False,num_workers=8,pin_memory = False, sampler = test_sampler)
-
- for epoch in range(config.epochs):
- train_sampler.set_epoch(epoch)
-
- adjust_learning_rate(optimizer, epoch, config)
-
- # train for one epoch
- st_time = time.time()
- train(train_loader, model_t,model_s, optimizer, epoch, config, logger)
- end_time = time.time()
-
- run_time = end_time - st_time
- logger.info(f"Running time for epoch: {run_time}")
- mae = validate(val_loader,model_s, config, logger, prefix = "val")
- #test_acc = validate(test_loader,model_t, config, logger, prefix = "test")
-
- is_best = (mae < best_mae)
-
- best_mae = min(mae, best_mae)
- save_dir = f"{root}/checkpoints/{config.env_name}"
- if not os.path.exists(save_dir):
- try:
- os.makedirs(save_dir)
- except:
- pass # multiple processors bug
-
- if is_best and config.local_rank == 0:
- state = {
- 'epoch': epoch + 1,
- 'state_dict': model_s.module.state_dict(),
- 'best_acc1': best_mae,
- 'amp': amp.state_dict(),
- 'optimizer': optimizer.state_dict(),
- }
- torch.save(state, f'{save_dir}/{config.env_name}_epoch_{epoch}_{best_mae}')
-
- def accuracy(output, labels):
- preds = output.max(1)[1].type_as(labels)
- correct = preds.eq(labels).double()
- correct = correct.sum()
- return correct / len(labels)
-
- def train(train_loader, model_t,model_s, optimizer, epoch, config,logger):
- #return
- losses = AverageMeter('Loss', ':.4e')
- loss_mae = AverageMeter('MAE', ':6.2f')
-
- progress = ProgressMeter(len(train_loader), [losses, loss_mae],
- prefix="Epoch: [{}]".format(epoch), logger = logger)
-
- model_s.train()
-
- prefetcher = data_prefetcher(train_loader)
- images, target, yy, bc, indices = prefetcher.next()
- i = 0
- optimizer.zero_grad()
- optimizer.step()
- while images is not None:
- out_t = model_t(images)
- out_s = model_s(images)
- pred = torch.sum(out_s['p'] * bc, dim = 1)
-
- loss = F.kl_div(out_s['y'], yy) + torch.nn.MSELoss()(out_s['fea'][-1], out_t['fea'][-1])
- mae = F.l1_loss(pred, target)
-
- torch.distributed.barrier()
- reduced_loss = reduce_mean(loss, config.nprocs)
- reduced_mae = reduce_mean(mae, config.nprocs)
-
- losses.update(reduced_loss.item(), images.size(0))
- loss_mae.update(reduced_mae.item(), images.size(0))
-
- optimizer.zero_grad()
- with amp.scale_loss(loss, optimizer) as scaled_loss:
- scaled_loss.backward()
- optimizer.step()
-
- if i % config.print_freq == 0 and config.local_rank == 0:
- progress.display(i)
-
- i += 1
-
- images, target, yy, bc,indices = prefetcher.next()
-
- if config.local_rank == 0:
- logger.info(f"[train loss]: {round(float(losses.avg),4)}, [train mae]: {round(float(loss_mae.avg),4)}")
-
-
- from sklearn.metrics import roc_auc_score,f1_score,confusion_matrix
- def validate(val_loader, model, config, logger, prefix = ""):
-
- loss_metric = AverageMeter('mae', ':6.2f')
- model.eval()
-
- with torch.no_grad():
- prefetcher = data_prefetcher(val_loader)
- images, target, yy, bc,indices = prefetcher.next()
- while images is not None:
- out = model(images)
- pred = torch.sum(out['p'] * bc, dim = 1)
-
- #loss = F.kl_div(out_t['y'], yy)
- mae = F.l1_loss(pred, target)
-
- torch.distributed.barrier()
- reduced_mae = reduce_mean(mae, config.nprocs)
- loss_metric.update(reduced_mae.item(), images.size(0))
-
- images, target, yy, bc, indices = prefetcher.next()
- if config.local_rank == 0:
- logger.info(f"\033[32m >>>>>>>> [{prefix}-mae]: {round(float(loss_metric.avg),4)} \033[0m")
-
- return round(loss_metric.avg,4)
-
-
- if __name__ == '__main__':
- main()
|