|
- # -*- coding: utf-8 -*-
- """
- @Author : zhwzhong
- @License : (C) Copyright 2013-2018, hit
- @Contact : zhwzhong@hit.edu.cn
- @Software: PyCharm
- @File : trainer.py
- @Time : 2022/3/1 20:06
- @Desc :
- """
- import gc
- import os
- import utils
- import torch
- import numpy as np
- from PIL import Image
-
def train_one_epoch(model, criterion, train_data, optimizer, device, epoch, args, age):
    """Run one training epoch and return the averaged meters plus the updated ``age``.

    Args:
        model: Network being trained; called with a batch of samples and expected
            to return a mapping that contains ``'out_img'``.
        criterion: Loss module, called as ``criterion(prediction, target, age)``.
        train_data: Sized iterable of training batches (``len()`` is taken on it).
        optimizer: Optimizer that is stepped once per batch.
        device: Device the batches are moved to via ``utils.to_device``.
        epoch: Epoch index, used only in the log header.
        args: Options object; reads ``print_freq``, ``mix_up``, ``mix_alpha`` and
            ``split_loss``.
        age: Running counter advanced by ``1 / len(train_data)`` for every batch
            (i.e. by one per full epoch) and forwarded to ``criterion``.

    Returns:
        A tuple ``(stats, age)`` where ``stats`` maps every meter name to its
        ``global_avg`` and ``age`` is the advanced counter.
    """
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    num_batches = len(train_data)
    # torch.cuda.synchronize() raises on hosts without CUDA, so it is only
    # issued when a CUDA runtime is actually present.
    cuda_available = torch.cuda.is_available()

    for samples in metric_logger.log_every(train_data, args.print_freq, header):
        age += 1.0 / num_batches
        samples = utils.to_device(samples, device)
        out = model(utils.mix_up(samples, args.mix_alpha) if args.mix_up else samples)

        if args.split_loss:
            # Sum the loss over chunks of two entries along dim 1.
            loss = 0
            for out_img, gt_img in zip(torch.split(out['out_img'], 2, 1),
                                       torch.split(samples['gt_img'], 2, 1)):
                loss += criterion(out_img, gt_img, age)
        else:
            loss = criterion(out['out_img'], samples['gt_img'], age)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if cuda_available:
            torch.cuda.synchronize()
        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # Drop the batch references before the next iteration allocates new ones.
        del samples, out
        gc.collect()

    metric_logger.synchronize_between_processes()
    torch.cuda.empty_cache()
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, age
-
-
@torch.no_grad()
def evaluate(model, criterion, test_name, val_data, device, args):
    """Evaluate ``model`` on ``val_data`` and return the rounded average metrics.

    Each batch is moved to ``device`` and run through the model (or through
    ``utils.self_ensemble`` when ``args.self_ensemble`` is set). Every sample is
    cropped to the window given by ``samples['pad']`` and scored with
    ``utils.calc_metrics``. When ``args.save_result`` is set, the prediction is
    converted with ``utils.tensor2uint`` and written as two PNG files under
    ``./results/<dataset>/<file_name>/``: the first three entries along the last
    axis as ``L.png`` and the remaining ones as ``R.png``.

    Args:
        model: Network to evaluate.
        criterion: Accepted for interface compatibility; not used here (no loss
            is computed during evaluation).
        test_name: Label used in the log header and as key prefix of the result.
        val_data: Iterable of validation batches.
        device: Device the batches are moved to.
        args: Options object; reads ``dataset``, ``file_name``, ``save_result``,
            ``self_ensemble``, ``ensemble_mode`` and ``data_range``.

    Returns:
        Dict mapping ``'<test_name>_<metric>'`` to the meter's ``global_avg``
        rounded to three decimals.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = '{}:'.format(test_name)

    save_dir = './results/{}/{}/'.format(args.dataset, args.file_name)
    if args.save_result:
        utils.create_dir(save_dir)

    for samples in metric_logger.log_every(val_data, 1, header):
        samples = utils.to_device(samples, device)
        if args.self_ensemble:
            out = utils.self_ensemble(samples, model, args.ensemble_mode)
        else:
            out = model(samples)

        batch_size = samples['gt_img'].size(0)
        for idx in range(batch_size):
            # Keep the batch dimension by slicing out a length-1 range.
            one = slice(idx, idx + 1)
            pred = out['out_img'][one]
            target = samples['gt_img'][one]
            pad = samples['pad']
            top, bottom = pad['top'][one], pad['bottom'][one]
            left, right = pad['left'][one], pad['right'][one]

            # Crop both images to the same window on the last two axes
            # (presumably strips padding added by the data pipeline).
            window = (slice(None), slice(None), slice(top, bottom), slice(left, right))
            target = target[window]
            pred = pred[window]
            metrics = utils.calc_metrics(pred, target, args)

            names = samples['img_name']
            if args.save_result:
                print('Image Saved to {}'.format(save_dir + names[idx]))
                pred_uint = utils.tensor2uint(pred, data_range=args.data_range)

                left_view, right_view = pred_uint[:, :, :3], pred_uint[:, :, 3:]
                for view, suffix in ((left_view, 'L.png'), (right_view, 'R.png')):
                    Image.fromarray(view).save(
                        os.path.join(save_dir, names[idx].replace('rgb.npy', suffix))
                    )

            # Each sample contributes one observation per metric.
            for key, val in metrics.items():
                metric_logger.meters[key].update(val, n=1)

        # Release the batch before loading the next one.
        del samples, out
        gc.collect()
        torch.cuda.empty_cache()

    metric_logger.synchronize_between_processes()

    metric_out = {}
    for key, meter in metric_logger.meters.items():
        metric_out['{}_'.format(test_name) + key] = round(meter.global_avg, 3)
    print(metric_out)
    torch.cuda.empty_cache()
    return metric_out
|