# -*- coding: utf-8 -*-
"""
@Author  : zhwzhong
@License : (C) Copyright 2013-2018, hit
@Contact : zhwzhong@hit.edu.cn
@Software: PyCharm
@File    : main.py
@Time    : 2022/3/1 20:06
@Desc    :
"""
import os
import loss
import data
import json
import utils
import torch
import models
from options import args
from timm.scheduler import create_scheduler
from trainer import train_one_epoch, evaluate
from torch.utils.tensorboard import SummaryWriter
from timm.optim import create_optimizer_v2, optimizer_kwargs

def main():
    model = models.get_model(args)
    device = torch.device(args.device)
    train_data = data.get_loader(args, 'train')
    writer = SummaryWriter('./logs/{}/{}'.format(args.dataset, args.file_name))

    model.to(device)
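    # Wrap the model for multi-GPU execution: DistributedDataParallel when
    # launched with torch.distributed, plain DataParallel otherwise.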
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
    else:
        model = torch.nn.parallel.DataParallel(model, device_ids=list(range(args.num_gpus)))
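    # Keep a handle to the unwrapped model for checkpointing and parameter counting.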
    model_without_ddp = model.module
    print('===> Parameter Number:', utils.get_parameter_number(model_without_ddp))
    criterion = loss.Loss(args)
    optimizer = create_optimizer_v2(model_without_ddp, **optimizer_kwargs(cfg=args))
    lr_scheduler, _ = create_scheduler(args, optimizer)
    cp_path = './checkpoints/{}/{}'.format(args.dataset, args.file_name)
    os.makedirs(cp_path, exist_ok=True)  # ensure the checkpoint directory exists (assumes save_on_master does not create it)

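    # Warm-start from pre-trained weights only (optimizer/scheduler state is not restored).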
    if os.path.exists(args.load_name) and args.pre_train:
        checkpoint = torch.load(args.load_name, map_location=torch.device('cpu'))
        model_without_ddp.load_state_dict(checkpoint['model'])
        print('Loading Checkpoint from {}...'.format(args.load_name))

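    # Resuming restores the full training state: model, optimizer, scheduler, and epoch counter.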
    if args.resume or args.test_only:
        checkpoint = torch.load('{}/{}'.format(cp_path, args.load_name), map_location=torch.device('cpu'))
        args.start_epoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer'])
        model_without_ddp.load_state_dict(checkpoint['model'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

        print('Loading Checkpoint from {}/{}...'.format(cp_path, args.load_name))

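    # Run a full evaluation pass on the restored weights before (or instead of) training.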
    if args.resume or args.test_only or args.pre_train:
        for test_name in args.test_set:
            val_data = data.get_loader(args, test_name)
            evaluate(model, criterion, test_name, device=device, val_data=val_data, args=args)

    if not args.test_only:
        if args.tail_only_iter > 0:
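            # Presumably restricts training to the tail layers during the warm-up
            # phase (exact behavior depends on utils.freeze_tail).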
            utils.freeze_tail(model)

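        # Running counter threaded through the trainer; train_one_epoch returns the updated value each epoch.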
        age = 0
        for epoch in range(args.start_epoch, args.epochs):

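            # Presumably un-freezes the rest of the model once the tail-only
            # warm-up phase is over.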
            if 0 < args.tail_only_iter < epoch:
                utils.freeze_tail(model, True)

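            # Re-seed the distributed sampler so each epoch sees a different shuffle.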
            if args.distributed:
                train_data.sampler.set_epoch(epoch)

            train_stats, age = train_one_epoch(model, criterion, train_data, optimizer, device, epoch, args, age)
            log_stats = {f'TRAIN_{k}'.upper(): v for k, v in train_stats.items()}
            lr_scheduler.step(epoch)

            for test_name in args.test_set:
                val_data = data.get_loader(args, test_name)
                test_stats = evaluate(model, criterion, test_name, device=device, val_data=val_data, args=args)
                log_stats.update({**{k.upper(): v for k, v in test_stats.items()}, 'EPOCH': epoch})

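            # Only the main process writes TensorBoard scalars and the JSON log line;
            # save_on_master presumably applies the same rank-0 guard to checkpointing.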
            if utils.is_main_process():
                for k, v in log_stats.items():
                    if k != 'EPOCH':
                        writer.add_scalar(k.replace('_', '/'), v, epoch)
                with open("./logs/{}/{}/log.txt".format(args.dataset, args.file_name), 'a') as f:
                    f.write(json.dumps(log_stats) + "\n")

            utils.save_on_master({
                'optimizer': optimizer.state_dict(),
                'model': model_without_ddp.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args,
            }, '{}/model_{}.pth'.format(cp_path, str(epoch).zfill(6)))


if __name__ == '__main__':
    main()