|
- from config import *
-
- import json
- import os
- import pprint as pp
- import random
- from datetime import date
- from pathlib import Path
-
- import numpy as np
- import torch
- import torch.backends.cudnn as cudnn
- from torch import optim as optim
-
-
def setup_train(args):
    """Prepare everything a training run needs and return its export folder.

    Steps, in order: configure visible GPUs via ``set_up_gpu``, create a new
    experiment folder, dump the run configuration to ``config.json`` inside it,
    and pretty-print every argument whose value is not ``None``.

    Returns:
        The path of the newly created experiment folder.
    """
    set_up_gpu(args)
    run_folder = create_experiment_export_folder(args)
    export_experiments_config_as_json(args, run_folder)

    # Only show arguments that were actually set; width=1 forces one entry per line.
    shown_args = dict(pair for pair in vars(args).items() if pair[1] is not None)
    pp.pprint(shown_args, width=1)
    return run_folder
-
-
def create_experiment_export_folder(args):
    """Create and return a fresh, uniquely indexed folder for this experiment run.

    The folder lives under ``args.experiment_dir``. Its name is built by
    ``get_name_of_experiment_path`` from ``args.experiment_description``.

    Args:
        args: namespace providing ``experiment_dir`` and ``experiment_description``.

    Returns:
        The path of the newly created run folder (as produced by
        ``get_name_of_experiment_path``).

    Raises:
        FileExistsError: if the chosen run folder appeared concurrently.
    """
    experiment_dir, experiment_description = args.experiment_dir, args.experiment_description
    # makedirs(exist_ok=True) also creates missing parent directories. It avoids the
    # race between an exists() check and mkdir() when several runs start at once.
    os.makedirs(experiment_dir, exist_ok=True)
    experiment_path = get_name_of_experiment_path(experiment_dir, experiment_description)
    os.mkdir(experiment_path)
    print('Folder created: ' + os.path.abspath(experiment_path))
    return experiment_path
-
-
def get_name_of_experiment_path(experiment_dir, experiment_description):
    """Build the path for a new run: ``<dir>/<description>_<today>_<n>``.

    ``<today>`` is ``str(date.today())``. ``<n>`` is the first index for which
    no such path exists yet, as reported by ``_get_experiment_index``.
    """
    stem = os.path.join(experiment_dir, experiment_description + "_" + str(date.today()))
    return "{}_{}".format(stem, _get_experiment_index(stem))
-
-
def _get_experiment_index(experiment_path):
    """Return the smallest n >= 0 such that ``experiment_path + "_" + str(n)`` does not exist."""
    index = 0
    while True:
        candidate = experiment_path + "_" + str(index)
        if not os.path.exists(candidate):
            return index
        index += 1
-
-
def load_weights(model, path):
    """Placeholder for loading weights from ``path`` into ``model``.

    NOTE(review): this is currently a no-op. It reads neither argument and always
    returns None. The actual checkpoint-loading logic lives in
    ``load_pretrained_weights``; confirm whether any caller expects this stub
    to do something.
    """
    return None
-
-
def save_test_result(export_root, result):
    """Serialize ``result`` as 2-space-indented JSON into ``<export_root>/test_result.txt``.

    Any existing file with that name is overwritten.
    """
    target = Path(export_root) / 'test_result.txt'
    with open(target, 'w') as out:
        json.dump(result, out, indent=2)
-
-
def export_experiments_config_as_json(args, experiment_path):
    """Dump ``vars(args)`` as 2-space-indented JSON to ``<experiment_path>/config.json``."""
    config_file = os.path.join(experiment_path, 'config.json')
    with open(config_file, 'w') as handle:
        json.dump(vars(args), handle, indent=2)
-
-
def fix_random_seed_as(random_seed):
    """Seed every RNG used in training and make cuDNN deterministic.

    Seeds, in this order: Python's ``random``, torch (CPU), torch on all CUDA
    devices, and NumPy's global RNG. Then sets ``cudnn.deterministic = True``
    and ``cudnn.benchmark = False``.
    """
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed)
    for seed_fn in seeders:
        seed_fn(random_seed)
    cudnn.deterministic, cudnn.benchmark = True, False
-
-
def set_up_gpu(args):
    """Expose only the GPUs listed in ``args.device_idx`` and record how many there are.

    ``args.device_idx`` is a comma-separated string of device ids, e.g. ``"0"`` or
    ``"0,1"``. Empty entries and surrounding whitespace are ignored. So ``"0,1,"``
    and ``"0, 1"`` both yield two devices, and ``""`` yields zero.

    Side effects:
        Sets ``os.environ['CUDA_VISIBLE_DEVICES']`` to the normalized id list
        and ``args.num_gpu`` to its length.

    Raises:
        TypeError: if ``args.device_idx`` is not a string.
    """
    device_idx = args.device_idx
    if not isinstance(device_idx, str):
        raise TypeError('args.device_idx must be a comma-separated string, got {!r}'.format(device_idx))
    # A plain split(",") counted blank entries, e.g. "0,1," -> 3 and "" -> 1.
    device_ids = [token.strip() for token in device_idx.split(",") if token.strip()]
    os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(device_ids)
    args.num_gpu = len(device_ids)
-
-
def load_pretrained_weights(model, path, map_location='cpu'):
    """Load model weights from the checkpoint file at ``path``.

    The checkpoint is read as a mapping. The weights are taken from
    ``STATE_DICT_KEY`` (from ``config``) if present, otherwise from ``'state_dict'``.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        path: checkpoint file path, resolved to an absolute path before loading.
        map_location: forwarded to ``torch.load``. It defaults to ``'cpu'`` so that
            a checkpoint saved on a GPU that is not visible in this process can
            still be loaded; ``load_state_dict`` copies the tensors into the
            model's existing parameters.

    Raises:
        KeyError: if neither key is present in the checkpoint.
    """
    chk_dict = torch.load(os.path.abspath(path), map_location=map_location)
    model_state_dict = chk_dict[STATE_DICT_KEY] if STATE_DICT_KEY in chk_dict else chk_dict['state_dict']
    model.load_state_dict(model_state_dict)
-
-
def setup_to_resume(args, model, optimizer, map_location='cpu'):
    """Restore model and optimizer state from a previous run's most recent checkpoint.

    Reads ``<args.resume_training>/models/checkpoint-recent.pth``. Model weights
    come from ``STATE_DICT_KEY`` and optimizer state from
    ``OPTIMIZER_STATE_DICT_KEY`` (both from ``config``).

    Args:
        args: namespace providing ``resume_training`` (the earlier run's folder).
        model: module whose ``load_state_dict`` receives the weights.
        optimizer: optimizer whose ``load_state_dict`` receives its saved state.
        map_location: forwarded to ``torch.load``. It defaults to ``'cpu'`` so that
            checkpoints saved on a GPU that is not visible in this process still
            load; the model and optimizer ``load_state_dict`` calls place the
            tensors next to the current parameters.

    Raises:
        KeyError: if either expected key is missing from the checkpoint.
    """
    checkpoint_path = os.path.join(os.path.abspath(args.resume_training), 'models', 'checkpoint-recent.pth')
    chk_dict = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(chk_dict[STATE_DICT_KEY])
    optimizer.load_state_dict(chk_dict[OPTIMIZER_STATE_DICT_KEY])
-
-
def create_optimizer(model, args):
    """Build the optimizer named by ``args.optimizer`` over ``model.parameters()``.

    ``'Adam'`` selects Adam with ``args.lr`` and ``args.weight_decay``. Any other
    value falls back to SGD, which additionally uses ``args.momentum``.
    """
    use_adam = args.optimizer == 'Adam'
    params = model.parameters()
    if not use_adam:
        return optim.SGD(params, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
    return optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
-
-
class AverageMeterSet(object):
    """A collection of named ``AverageMeter`` objects, created lazily on first update."""

    def __init__(self, meters=None):
        # Use a fresh dict when nothing (or an empty mapping) is supplied.
        self.meters = meters or {}

    def __getitem__(self, key):
        """Return the meter stored under ``key``.

        For an unknown key, a throwaway ``AverageMeter`` that has recorded a
        single 0 is returned. That meter is not stored in ``self.meters``.
        """
        if key in self.meters:
            return self.meters[key]
        placeholder = AverageMeter()
        placeholder.update(0)
        return placeholder

    def update(self, name, value, n=1):
        """Forward ``(value, n)`` to the meter called ``name``, creating it first if needed."""
        meters = self.meters
        if name not in meters:
            meters[name] = AverageMeter()
        meters[name].update(value, n)

    def reset(self):
        """Reset every tracked meter in place."""
        for tracked in self.meters.values():
            tracked.reset()

    def _collect(self, attribute, format_string):
        # Shared by values/averages/sums/counts: map each formatted name to one meter attribute.
        return {format_string.format(name): getattr(meter, attribute)
                for name, meter in self.meters.items()}

    def values(self, format_string='{}'):
        """Latest value of each meter, keyed by ``format_string.format(name)``."""
        return self._collect('val', format_string)

    def averages(self, format_string='{}'):
        """Running average of each meter, keyed by ``format_string.format(name)``."""
        return self._collect('avg', format_string)

    def sums(self, format_string='{}'):
        """Accumulated sum of each meter, keyed by ``format_string.format(name)``."""
        return self._collect('sum', format_string)

    def counts(self, format_string='{}'):
        """Accumulated count of each meter, keyed by ``format_string.format(name)``."""
        return self._collect('count', format_string)
-
-
class AverageMeter(object):
    """Computes and stores the most recent value and the sample-weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, average, sum and count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples.

        ``self.val`` becomes ``val``. ``self.sum`` accumulates ``val * n``, so
        ``self.avg`` is the average over all recorded samples. The previous code
        added ``val`` unweighted while still adding ``n`` to the count, which made
        the average wrong for any ``n > 1``. With the default ``n=1`` the two
        behave identically.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard so an update with n=0 on an empty meter does not divide by zero.
        self.avg = self.sum / self.count if self.count else 0

    def __format__(self, format_spec):
        """Render as ``"<val> (<avg>)"`` with ``format_spec`` applied to both numbers."""
        return "{self.val:{spec}} ({self.avg:{spec}})".format(self=self, spec=format_spec)
|