- import os
- import sys
- import argparse
- sys.path.append('core')
- import numpy as np
- import random
- import logging
- import mindspore as ms
- from mindspore import context
- import mindspore.nn as nn
- import mindspore.dataset as ds
- from mindspore.profiler import Profiler
- from mindspore.train.callback import Callback, LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
- from mindspore import dtype as mstype
-
-
- from core.ms_datasets import MpiSintel, FlyingChairs, FlyingThings3D, KITTI, HD1K
- from core.network import RAFTGMA
- from core.logger import get_logger
- from core.seq_loss import MyOptimizer, MyWithLossCell, SequenceLossCell, CustomTrainOneStepCell
- from core.train_one_step_loss_scale import TrainOneStepWithLossScaleCell, GMALossMonitor
- from mindspore.communication import init, get_rank, get_group_size
- from mindspore.context import ParallelMode
-
- from mindspore import Tensor, export, load_checkpoint, load_param_into_net, save_checkpoint
-
-
-
- # Default arguments used when launching on the OpenI platform (Sintel fine-tuning stage)
- def get_openi_default_args(args):
- args.device_target = 'Ascend'
- args.name = 'gma-sintel'
- args.stage = 'sintel'
- args.validation = 'sintel'
- args.num_steps = 120000
- args.lr = 0.000125
- args.image_size = [368, 768]
- args.wdecay = 0.00001
- args.gamma = 0.85
- args.dataset_path = '/mass_store/dataset/sintel'
-
- def fetch_dataloader(args, rank_size = 0, rank_id = 0, TRAIN_DS='C+T+K+S+H'):
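- """Build the MindSpore dataset pipeline for the stage selected by args.stage;
- rank_size/rank_id shard the data across devices for distributed training."""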
-
- if not args.is_dist:
- rank_size = None
- rank_id = None
- num_workers = 16
- multipcond = False  # keep python_multiprocessing off inside GeneratorDataset
-
- if args.stage == 'chairs':
- aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
- train_dataset = FlyingChairs(aug_params, split='training')
- train_loader = ds.GeneratorDataset(train_dataset, column_names=["img1", "img2", "flow", "valid"],
- num_parallel_workers=8, shuffle=True)
- train_loader = train_loader.batch(args.batch_size, drop_remainder=True)
- return train_loader
-
- elif args.stage == 'things':
- aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
- clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass', split='training')
- final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass', split='training')
- clean_loader = ds.GeneratorDataset(clean_dataset, column_names=["img1", "img2", "flow", "valid"], shuffle=True)
- final_loader = ds.GeneratorDataset(final_dataset, column_names=["img1", "img2", "flow", "valid"], shuffle=True)
- train_loader = clean_loader.concat(final_loader)
- train_loader = train_loader.batch(args.batch_size, drop_remainder=True)
- # train_dataset = clean_dataset + final_dataset
- return train_loader
-
- elif args.stage == 'sintel':
- aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
- things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
- sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
- sintel_final = MpiSintel(aug_params, split='training', dstype='final')
-
- if TRAIN_DS == 'C+T+K+S+H':
- kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
- hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
- # train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
- sintel_clean_loader = ds.GeneratorDataset(sintel_clean, column_names=["img1", "img2", "flow", "valid"],
- shuffle=True, num_parallel_workers=num_workers,
- python_multiprocessing=multipcond,
- num_shards=rank_size, shard_id=rank_id).repeat(100)
- sintel_final_loader = ds.GeneratorDataset(sintel_final, column_names=["img1", "img2", "flow", "valid"],
- shuffle=True, num_parallel_workers=num_workers,
- python_multiprocessing=multipcond,
- num_shards=rank_size, shard_id=rank_id).repeat(100)
- things_loader = ds.GeneratorDataset(things, column_names=["img1", "img2", "flow", "valid"], shuffle=True,
- num_parallel_workers=num_workers,
- python_multiprocessing=multipcond, num_shards=rank_size, shard_id=rank_id)
- kitti_loader = ds.GeneratorDataset(kitti, column_names=["img1", "img2", "flow", "valid"],
- shuffle=True, num_parallel_workers=num_workers,
- python_multiprocessing=multipcond,
- num_shards=rank_size, shard_id=rank_id).repeat(200)
- hd1k_loader = ds.GeneratorDataset(hd1k, column_names=["img1", "img2", "flow", "valid"],
- shuffle=True, num_parallel_workers=num_workers,
- python_multiprocessing=multipcond,
- num_shards=rank_size, shard_id=rank_id).repeat(5)
- train_loader = sintel_clean_loader.concat(sintel_final_loader).concat(kitti_loader).concat(hd1k_loader).concat(things_loader)
- # train_loader = sintel_clean_loader.concat(sintel_final_loader)
- train_loader = train_loader.batch(batch_size=args.batch_size, drop_remainder=True)
- return train_loader
-
- elif TRAIN_DS == 'C+T+K/S':
- sintel_clean_loader = ds.GeneratorDataset(sintel_clean, column_names=["img1", "img2", "flow", "valid"],
- shuffle=True).repeat(100)
- sintel_final_loader = ds.GeneratorDataset(sintel_final, column_names=["img1", "img2", "flow", "valid"],
- shuffle=True).repeat(100)
- things_loader = ds.GeneratorDataset(things, column_names=["img1", "img2", "flow", "valid"], shuffle=True)
- train_loader = sintel_clean_loader.concat(sintel_final_loader).concat(things_loader)
- train_loader = train_loader.batch(args.batch_size, drop_remainder=True)
- # train_dataset = 100*sintel_clean + 100*sintel_final + things
- return train_loader
-
- elif args.stage == 'kitti':
- aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
- train_dataset = KITTI(aug_params, split='training')
- print("+++++Training with %d image pairs" % len(train_dataset))
- train_loader = ds.GeneratorDataset(source=train_dataset, num_parallel_workers=8,
- column_names=['img1', 'img2', 'flow', 'valid'], shuffle=True,
- num_shards=rank_size, shard_id=rank_id)
- train_loader = train_loader.batch(batch_size=args.batch_size, drop_remainder=True)
-
- return train_loader
-
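- # Debug helper: dump the first batch produced by the pipeline to .npy files for offline inspection.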
- def mock(dataloader):
- for _, data_blob in enumerate(dataloader.create_tuple_iterator()):
- image1, image2, flow, valid = data_blob
- np.save("image1.npy", image1.asnumpy())
- np.save("image2.npy", image2.asnumpy())
- np.save("flow.npy", flow.asnumpy())
- np.save("valid.npy", valid.asnumpy())
- print("[DEBUG] saved!!!!")
- break
-
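- # Precompute a one-cycle learning-rate schedule (linear warm-up to max_lr, then linear decay to min_lr)
- # as a flat list; MindSpore optimizers accept such a per-step learning-rate sequence directly.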
- class OneCycleLR:
- def __init__(self, max_lr, total_steps, pct_start):
- self.total_steps = total_steps
- self._schedule_phases = [
- {
- 'end_step': float(pct_start * self.total_steps) - 1,
- 'start_lr': 'initial_lr',
- 'end_lr': 'max_lr',
- 'start_momentum': 'max_momentum',
- 'end_momentum': 'base_momentum',
- },
- {
- 'end_step': self.total_steps - 1,
- 'start_lr': 'max_lr',
- 'end_lr': 'min_lr',
- 'start_momentum': 'base_momentum',
- 'end_momentum': 'max_momentum',
- },
- ]
- div_factor = 25.
- final_div_factor = 1e4
- # max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
- self.initial_lr = max_lr / div_factor
- self.max_lr = max_lr
- self.min_lr = self.initial_lr / final_div_factor
- self.group = {'initial_lr': self.initial_lr, 'max_lr': self.max_lr, 'min_lr': self.min_lr}
- # if last_epoch == -1:
- # for idx, group in enumerate(self.optimizer.param_groups):
- # group['initial_lr'] = max_lrs[idx] / div_factor
- # group['max_lr'] = max_lrs[idx]
- # group['min_lr'] = group['initial_lr'] / final_div_factor
-
- def _annealing_linear(self, start, end, pct):
- "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
- return (end - start) * pct + start
-
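- # Walk every step from 0 to total_steps, interpolating linearly within the current phase.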
- def get_lrs(self):
- lrs = []
- step_num = 0
- while step_num < self.total_steps:
- # for group in self.optimizer.param_groups:
- start_step = 0
- for i, phase in enumerate(self._schedule_phases):
- end_step = phase['end_step']
- if step_num <= end_step or i == 1:
- pct = (step_num - start_step) / (end_step - start_step)
- computed_lr = self._annealing_linear(self.group[phase['start_lr']], self.group[phase['end_lr']], pct)
- break
- start_step = phase['end_step']
- lrs.append(computed_lr)
- step_num = step_num + 1
-
- return lrs
-
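- # Callback that runs the MindSpore Profiler only between start_step and stop_step, then writes the analysis.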
- class StopAtStep(Callback):
- def __init__(self, start_step, stop_step):
- super(StopAtStep, self).__init__()
- self.start_step = start_step
- self.stop_step = stop_step
- self.profiler = Profiler(start_profile=False)
- def step_begin(self, run_context):
- cb_params = run_context.original_args()
- step_num = cb_params.cur_step_num
- if step_num == self.start_step:
- print("[DEBUG] profiler start")
- self.profiler.start()
- def step_end(self, run_context):
- cb_params = run_context.original_args()
- step_num = cb_params.cur_step_num
- if step_num == self.stop_step:
- print("[DEBUG] profiler stop")
- self.profiler.stop()
- self.profiler.analyse()
- print('[DEBUG] profiler analyse')
- def end(self, run_context):
- print('[DEBUG] end')
- self.profiler.analyse()
-
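- # Training entry point: configure the execution context (optionally data-parallel), build the data pipeline,
- # model, loss, learning-rate schedule and optimizer, then train via ms.Model with the usual callbacks.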
- def main(args):
- context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
- device_num = int(os.getenv('RANK_SIZE', '1'))
- # print("[DEBUG] 分布式模式:", args.is_dist)
- # print("[DEBUG] Device Num: ", device_num)
- if args.is_dist:
- context.reset_auto_parallel_context()
- context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True)
- init()
- rank_id = get_rank()
- rank_size = get_group_size()
- logs = get_logger(args.output, rank_id)
- else:
- rank_id = 0
- rank_size = 1
- logs = get_logger(args.output, rank_id)
- logs.info(f"is distributed {args.is_dist}")
- logs.info(f'Device number {device_num}')
-
- # Load the training dataset
- dataset_sink_mode = True
- if args.device_target == 'CPU':
- dataset_sink_mode = False
- if args.is_dist:
- dataloader = fetch_dataloader(args, rank_size, rank_id)
- else:
- dataloader = fetch_dataloader(args)
- total_steps = args.num_steps
- ds_size = dataloader.get_dataset_size()
- epoch_num = total_steps // ds_size  # derive the epoch count from the total step budget
- # print("[DEBUG] epochs:", epoch_num, ", steps per epoch:", ds_size)
- logs.info(f"epoch number: {epoch_num}, step number: {ds_size}")
-
- # mock(dataloader)
- #print(f"[DEBUG] dataset size: {dataloader.get_dataset_size()}")
-
- # Network
- model = RAFTGMA(args.__dict__)
- if args.restore_ckpt is not None:
- param_dict = load_checkpoint(args.restore_ckpt, strict_load=False)
- load_param_into_net(model, param_dict)
- logs.info(f"Load checkpoint: {args.restore_ckpt}")
- # print("[DEBUG] 加载checkpoint: ", args.restore_ckpt)
- # 损失函数
- loss_fn = SequenceLossCell(args.gamma)
-
- # Optimizer
- new_lr = args.lr * device_num  # scale the learning rate by the number of devices
- # [DEBUG] learning-rate tuning
- # lr_scheduler = OneCycleLR(max_lr=new_lr, total_steps=(ds_size*epoch_num+100), pct_start=0.05)
- lr_scheduler = OneCycleLR(max_lr=new_lr, total_steps=(epoch_num*ds_size+100)*2, pct_start=0.05)
-
- lrs = lr_scheduler.get_lrs()
- # print('[DEBUG] ',lrs[0:30], len(lrs))
- logs.info(f"Length of learning list {len(lrs)}")
- lrs = Tensor(np.array(lrs), mstype.float32)  # per-step learning-rate schedule as float32
-
- optimizer = nn.AdamWeightDecay(params=model.trainable_params(), learning_rate=lrs, weight_decay=args.wdecay, eps=args.epsilon)
- # optimizer = nn.Adam(params = model.trainable_params(), learning_rate = lr_scheduler.get_lrs(), weight_decay = args.wdecay, eps = args.epsilon)
- # print(lr_scheduler.get_lrs()[12004])
- # Wrap the network with the loss function
- net_with_criterion = MyWithLossCell(model, loss_fn)
- net_with_criterion.set_train(True)
-
- # Callback
- time_cb = TimeMonitor(data_size=1)
- # loss_cb = GMALossMonitor(log=logs)
- loss_cb = LossMonitor(per_print_times=1)
- callbacks = [time_cb, loss_cb]
- save_epoch = 1  # the original code saves per epoch, but MindSpore checkpoints are configured per step
- save_checkpoint_steps = save_epoch * ds_size
- ckpt_config = CheckpointConfig(save_checkpoint_steps=1000,
- keep_checkpoint_max=100)  # maximum number of checkpoint files, tentatively 100
- output_directory = args.output
- if args.is_dist:
- output_directory = os.path.join(output_directory, 'ckpt_'+str(get_rank())+'/')
- ckpt_cb = ModelCheckpoint(prefix="GMA", directory=output_directory, config=ckpt_config)
- callbacks.append(ckpt_cb)
- # Profiler Callback
- # prof_cb = StopAtStep(start_step=1, stop_step=4)
- # callbacks.append(prof_cb)
-
- # Dynamic loss scaling (keeps FP16 gradients from underflowing)
- scale_factor = 2
- scale_window = 2000
- loss_scale_manager = ms.DynamicLossScaleManager(init_loss_scale=2.**16, scale_factor=scale_factor,
- scale_window=scale_window)
- # manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 16, scale_factor=scale_factor, scale_window=2000)
-
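- # Collect training metadata (metrics, graphs, conv1/conv2 weight histograms) for visualisation in MindInsight.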
- specified = {"collect_metric": True, "histogram_regular": "^conv1.*|^conv2.*", "collect_graph": True,
- "collect_dataset_graph": True}
- summary_collector = ms.SummaryCollector(summary_dir="./summary_dir/summary_04", collect_specified_data=specified,
- collect_freq=1, keep_default_action=False, collect_tensor_freq=200)
- callbacks.append(summary_collector)
- # train_net = TrainOneStepWithLossScaleCell(net_with_criterion, optimizer, manager)
-
- # Build the training network
- # train_net = MyTrainStep(net_with_criterion, optimizer)
- # train_net = CustomTrainOneStepCell(network=net_with_criterion, optimizer=optimizer, clip=args.clip)
- # train_net.set_train()
- ms_model = ms.Model(net_with_criterion, optimizer=optimizer, loss_scale_manager=loss_scale_manager)
- # ms_model = ms.Model(train_net)
- ms_model.train(epoch=epoch_num, train_dataset=dataloader, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
-
-
-
- if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('--dataset_path', default='/root/xidian_wks/cyj/dataset/sintel')
- parser.add_argument('--device_target', default='Ascend')
- parser.add_argument('--name', default='bla', help="name your experiment")
- parser.add_argument('--stage', help="determines which dataset to use for training")
- parser.add_argument('--validation', type=str, nargs='+')
- parser.add_argument('--restore_ckpt', help="restore checkpoint")
- parser.add_argument('--output', type=str, default='checkpoints',
- help='output directory to save checkpoints and plots')
-
- parser.add_argument('--lr', type=float, default=0.00002)
- parser.add_argument('--num_steps', type=int, default=100000)
- parser.add_argument('--batch_size', type=int, default=1)
- parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
-
- parser.add_argument('--wdecay', type=float, default=.00005)
- parser.add_argument('--epsilon', type=float, default=1e-8)
- parser.add_argument('--clip', type=float, default=1.0)
- parser.add_argument('--dropout', type=float, default=0.0)
- parser.add_argument('--upsample-learn', action='store_true', default=False,
- help='If True, use learned upsampling, otherwise, use bilinear upsampling.')
- parser.add_argument('--gamma', type=float, default=0.85, help='exponential weighting')
- parser.add_argument('--iters', type=int, default=12)
- parser.add_argument('--val_freq', type=int, default=10000,
- help='validation frequency')
- parser.add_argument('--print_freq', type=int, default=100,
- help='printing frequency')
-
- parser.add_argument('--mixed_precision', default=False, action='store_true',
- help='use mixed precision')
- parser.add_argument('--model_name', default='', help='specify model name')
-
- parser.add_argument('--position_only', default=False, action='store_true',
- help='only use position-wise attention')
- parser.add_argument('--position_and_content', default=False, action='store_true',
- help='use position and content-wise attention')
- parser.add_argument('--num_heads', default=1, type=int,
- help='number of heads in attention and aggregation')
-
- # dist
- parser.add_argument('--is_dist', default=False, action='store_true')
-
- ms.set_seed(1234)
- np.random.seed(1234)
- random.seed(1234)
- args = parser.parse_args()
- main(args)
-