|
- import datetime
- import logging
- import math
- import time
- import os
- import argparse
- import mindspore as ms
- from mindspore import dtype as mstype
- from mindspore import load_checkpoint
- from mindspore import context
- from os import path as osp
- import sys
-
- from basicsr.data import build_dataloader, build_dataset
- from basicsr.data.data_sampler import EnlargedSampler
-
- from basicsr.utils_edvr.options import parse_options
- from basicsr.models import lr_scheduler as lr_scheduler
- from basicsr.archs import build_network
- from basicsr.losses import build_loss
- from mindspore.train.model import Model
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Callback
- from mindspore.nn import Accuracy
- from basicsr.metrics.metric import PSNR
- from basicsr.metrics.metric_util import do_eval
-
-
class EvalCallBack(Callback):
    """Callback that evaluates the network at a fixed epoch interval.

    Keeps the best result seen so far (by PSNR) and, when a dict is supplied
    via ``result_evaluation``, appends every metric value to a per-metric
    history list in that dict.
    """

    def __init__(self, eval_network, ds_val, eval_epoch_frq, epoch_size, metrics, result_evaluation=None):
        self.eval_network = eval_network
        self.ds_val = ds_val
        self.eval_epoch_frq = eval_epoch_frq
        self.epoch_size = epoch_size
        self.result_evaluation = result_evaluation
        self.metrics = metrics
        self.best_result = None
        # The evaluation network stays in inference mode for the whole run.
        self.eval_network.set_train(False)

    def epoch_end(self, run_context):
        """Run evaluation when the epoch hits the interval or is the last one."""
        cur_epoch = run_context.original_args().cur_epoch_num
        due = (cur_epoch % self.eval_epoch_frq == 0) or (cur_epoch == self.epoch_size)
        if not due:
            return
        result = do_eval(self.eval_network, self.ds_val, self.metrics, cur_epoch=cur_epoch)
        if self.best_result is None or result["psnr"] > self.best_result["psnr"]:
            self.best_result = result
        print(f"best evaluation result = {self.best_result}", flush=True)
        if isinstance(self.result_evaluation, dict):
            # Accumulate a history list per metric name.
            for metric_name, metric_value in result.items():
                self.result_evaluation.setdefault(metric_name, []).append(metric_value)
-
-
-
def init_net(opt):
    """
    Build the generator network described by ``opt['network_g']``.

    Args:
        opt (dict): Full option dict; must contain ``network_g`` (architecture
            spec passed to ``build_network``) and a ``path`` section.

    Returns:
        The constructed network. The network object is printed for inspection.
    """
    # define network
    net_g = build_network(opt['network_g'])

    # load pretrained models
    # NOTE(review): pretrained loading is currently disabled — the
    # load_network call below is commented out, so a configured
    # ``pretrain_network_g`` checkpoint is silently ignored and
    # ``param_key`` is computed but unused. Confirm whether this is intended.
    load_path = opt['path'].get('pretrain_network_g', None)
    if load_path is not None:
        param_key = opt['path'].get('param_key_g', 'params')
        #load_network(net_g, load_path, opt['path'].get('strict_load_g', True), param_key)
    print(net_g)
    return net_g
-
def init_opt(opt, net, lr):
    """
    Create the optimizer configured in ``opt['train']['optim_g']``.

    Args:
        opt (dict): Option dict; reads ``train.optim_g.type``.
        net: Network whose trainable parameters will be optimized.
        lr: Learning rate — a scalar or a per-step schedule (list).

    Returns:
        mindspore.nn.Optimizer: The configured optimizer.

    Raises:
        NotImplementedError: If the configured optimizer type is unsupported.
    """
    optim_type = opt['train']['optim_g']['type']
    if optim_type == 'Adam':
        # Only parameters with requires_grad=True are handed to the optimizer.
        optimizer = ms.nn.Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()),
                               learning_rate=lr)
    else:
        # Fix: error message typo ('supperted' -> 'supported').
        raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
    return optimizer
-
def init_loss(opt):
    """
    Build the pixel loss configured under ``opt['train']['pixel_opt']``.

    Args:
        opt (dict): Option dict; ``train.pixel_opt`` describes the loss.

    Returns:
        The loss built by ``build_loss``.

    Raises:
        ValueError: If ``train.pixel_opt`` is missing or empty.
    """
    pixel_opt = opt['train'].get('pixel_opt')
    if not pixel_opt:
        # Fix: the old message claimed "Both pixel and perceptual losses are
        # None", but only the pixel loss is ever checked or built here.
        raise ValueError('Pixel loss (train.pixel_opt) is None.')
    return build_loss(pixel_opt)
-
def get_position_from_periods(iteration, cumulative_period):
    """Get the position from a period list.

    Returns the index of the first (right-closest) boundary that is >=
    ``iteration``. For cumulative_period = [100, 200, 300, 400]:
    iteration 50 -> 0; iteration 210 -> 2; iteration 300 -> 2.
    If ``iteration`` exceeds every boundary, 0 is returned (the loop in the
    original implementation falls through without updating the index).

    Args:
        iteration (int): Current iteration.
        cumulative_period (list[int]): Cumulative period list.

    Returns:
        int: The position of the right-closest number in the period list.
    """
    return next(
        (pos for pos, boundary in enumerate(cumulative_period) if iteration <= boundary),
        0,
    )
-
def CosineAnnealingRestartLR(global_step, base_lr, periods, restart_weights=(1,), eta_min=0):
    """
    Cosine-annealing-with-restarts learning rate for a single step.

    Args:
        global_step (int): Current iteration.
        base_lr (float): Peak learning rate of each cycle.
        periods (list[int]): Iteration length of each cosine cycle.
        restart_weights (tuple): Amplitude scale applied per cycle.
        eta_min (float): Floor learning rate.

    Returns:
        float: Learning rate at ``global_step``.
    """
    # Cumulative cycle end-points, e.g. periods [10, 20] -> [10, 30].
    boundaries = [sum(periods[:i + 1]) for i in range(len(periods))]
    cycle = get_position_from_periods(global_step, boundaries)
    weight = restart_weights[cycle]
    # Iteration at which the current cycle (re)started.
    restart_step = 0 if cycle == 0 else boundaries[cycle - 1]
    # Phase in [0, pi] sweeps cos() from 1 down to -1 across the cycle.
    phase = math.pi * ((global_step - restart_step) / periods[cycle])
    return eta_min + weight * 0.5 * (base_lr - eta_min) * (1 + math.cos(phase))
-
-
def lr_edvr(opt, current_iter):
    """
    Learning rate at ``current_iter`` under the configured cosine-restart schedule.

    Args:
        opt (dict): Option dict; reads ``train.optim_g.lr`` and the
            ``train.scheduler`` section (periods, restart_weights, eta_min).
        current_iter (int): Current training iteration.

    Returns:
        float: Learning rate for this iteration.
    """
    # Fix: removed the dead `schedulers = [];` local (never used, stray semicolon).
    train_opt = opt['train']
    scheduler_opt = train_opt['scheduler']
    return CosineAnnealingRestartLR(current_iter,
                                    train_opt['optim_g']['lr'],
                                    scheduler_opt['periods'],
                                    scheduler_opt['restart_weights'],
                                    scheduler_opt['eta_min'])
-
def create_train_val_dataloader(opt, phase):
    """
    Build the dataloader for one phase.

    Args:
        opt (dict): Full option dict with ``datasets.train`` / ``datasets.val``
            sections plus ``world_size``, ``rank``, ``num_gpu``, ``manual_seed``.
        phase (str): One of 'train', 'val', 'test'.

    Returns:
        The dataloader produced by ``build_dataloader``.

    Raises:
        ValueError: If ``phase`` is not recognized.
    """
    # Fix: removed the dead `train_loader = []` local (never used).
    if phase == 'train':
        # NOTE(review): the enlarge ratio is hard-coded; any
        # ``dataset_enlarge_ratio`` present in the dataset opt is ignored —
        # confirm this is intentional.
        dataset_enlarge_ratio = 200
        dataset_opt = opt['datasets']['train']
        train_set = build_dataset(dataset_opt)
        train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
        dataset = build_dataloader(
            train_set,
            phase,
            num_gpu=opt['num_gpu'],
            dist=False,
            sampler=train_sampler,
            seed=opt['manual_seed'])
    elif phase in ['val', 'test']:
        # val and test both read the 'val' dataset options.
        dataset_opt = opt['datasets']['val']
        val_set = build_dataset(dataset_opt)
        dataset = build_dataloader(val_set, phase, num_gpu=opt['num_gpu'], dist=False, seed=opt['manual_seed'])
    else:
        raise ValueError(f'Dataset phase {phase} is not recognized.')
    return dataset
-
-
def train_pipeline(root_path):
    """
    End-to-end EDVR training: parse options, build dataloaders, network,
    loss, LR schedule and optimizer, train via mindspore ``Model``, then
    copy the output directory to OBS.

    Args:
        root_path (str): Repository root passed to ``parse_options``.
    """
    # parse options, set distributed setting, set ramdom seed
    opt, args = parse_options(root_path, is_train=True)
    # Hard-coded output locations for a ModelArts-style environment:
    # local training output dir and the OBS path it is mirrored to.
    obs_train_url = '/home/ma-user/work/EDVR/model/'
    args.train_url = '/home/work/user-job-dir/outputs/model/'
    args.save_checkpoint_path = args.train_url
    opt['root_path'] = root_path

    train_dataset = create_train_val_dataloader(opt,'train')
    val_dataset = create_train_val_dataloader(opt,'val')
    # NOTE(review): total iterations are fixed here rather than read from
    # opt['train'] — confirm this override of the YAML config is intended.
    total_iters=3000

    # create model
    print("*****************************start build model*****************************")
    net_m = init_net(opt)
    print("load net weights successfully")

    loss = init_loss(opt)
    print("load net loss successfully")

    # Precompute the per-iteration learning-rate list for the whole run and
    # hand it to the optimizer as a per-step schedule.
    lr = []
    for current_iter in range(0,total_iters):
        #print("lr:",lr_edvr(opt, current_iter))
        lr.append(lr_edvr(opt, current_iter))
    print("caculate the learning rate successfully")
    optimizer = init_opt(opt, net_m, lr)
    print("init the optimizer successfully")

    metrics = {
        "psnr": PSNR(rgb_range=255, shave=1),
    }
    # Evaluate every 300 epochs and again at the final epoch.
    eval_cb = EvalCallBack(net_m, val_dataset, 300, total_iters, metrics=metrics)

    model = Model(net_m, loss_fn=loss, optimizer=optimizer,amp_level="O2")
    print("create Model successfully")
    # training the model
    time_cb = TimeMonitor(data_size=10)
    #config_ck = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps)
    #ckpoint_cb = ModelCheckpoint(prefix="checkpoint_EDVR",directory=args.save_checkpoint_path,config=config_ck)
    loss_cb = LossMonitor()

    #eval_param_dict = {"model":model,"dataset":val_loader,"metrics_name":"Accuracy"}
    #eval_cb = EvalCallBack(apply_eval, eval_param_dict,)
    cbs = [time_cb, loss_cb,eval_cb]
    # NOTE(review): with dataset_sink_mode=True and sink_size=10, the first
    # argument counts sink cycles of 10 steps rather than dataset epochs —
    # verify against the MindSpore Model.train documentation.
    model.train(total_iters, train_dataset, callbacks=cbs,dataset_sink_mode=True, sink_size=10)

    ######################## copy the output model to OBS ########################
    # Copy the trained model data from the local run environment back to OBS;
    # the OpenI platform exposes it for download in the matching training task.
    # NOTE(review): ``mox`` (moxing) is never imported in this file, so this
    # call raises NameError which is swallowed by the except below — the
    # upload silently never happens. Confirm / add the moxing import.
    try:
        mox.file.copy_parallel(args.train_url, obs_train_url)
        print("Successfully Upload {} to {}".format(args.train_url,obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(args.train_url,obs_train_url) + str(e))
    ######################## copy the output model to OBS ########################
-
if __name__ == '__main__':
    # Eager (PyNative) execution on an Ascend device.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    # Repository root: two directory levels above this script.
    root_path = osp.dirname(osp.dirname(osp.abspath(__file__)))
    train_pipeline(root_path)
-
-
-
-
|