|
- import random
- import numpy as np
- import torch
- import argparse
- import mindspore as ms
- from mindspore import context
- import mindspore.nn as nn
- import mindspore.ops as ops
- import mindspore.dataset as ds
- from mindspore.profiler import Profiler
- from mindspore.train.callback import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
- from mindspore import dtype as mstype
- from core.network import RAFTGMA
- from core.utils.ms_utils import load_pytorch_state_dict
-
- import sys
- sys.path.append('core')
- from core.ms_datasets import MpiSintel, fetch_dataloader
- from core.seq_loss import MyTrainStep, MyWithLossCell, SequenceLossCell, CustomTrainOneStepCell
- from mindspore import Tensor, export, load_checkpoint, load_param_into_net, save_checkpoint
-
-
-
def _str_to_bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    argparse applies ``type`` only to strings; without a converter every
    non-empty string (including ``"False"``) would be truthy.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', '1', 'yes', 'y', 'on'):
        return True
    if normalized in ('false', '0', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args(argv=None):
    """Build the command-line parser for the GMA export script and parse it.

    Args:
        argv: optional list of argument strings. When None (the default),
            ``sys.argv[1:]`` is parsed, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', default='/root/xidian_wks/cyj/dataset/sintel')
    parser.add_argument('--device_target', default='Ascend')
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--dataset', help="dataset for evaluation")
    parser.add_argument('--iters', type=int, default=12)
    parser.add_argument('--num_heads', default=1, type=int,
                        help='number of heads in attention and aggregation')
    parser.add_argument('--position_only', default=False, action='store_true',
                        help='only use position-wise attention')
    parser.add_argument('--position_and_content', default=False, action='store_true',
                        help='use position and content-wise attention')
    # Previously untyped, so `--mixed_precision False` gave the truthy string
    # 'False'. Values are now parsed as booleans; a bare flag means True.
    parser.add_argument('--mixed_precision', default=True, type=_str_to_bool,
                        nargs='?', const=True, help='use mixed precision')
    parser.add_argument('--model_name')

    # Ablations
    parser.add_argument('--replace', default=False, action='store_true',
                        help='Replace local motion feature with aggregated motion features')
    parser.add_argument('--no_alpha', default=False, action='store_true',
                        help='Remove learned alpha, set it to 1')
    parser.add_argument('--no_residual', default=False, action='store_true',
                        help='Remove residual connection. Do not add local features with the aggregated features.')
    return parser.parse_args(argv)
-
-
def export_mindir(image1_path='image1.npy', image2_path='image2.npy',
                  file_name='gma_sintel_6'):
    """Load a RAFT-GMA network from weights and export it in MINDIR format.

    Options are read from the command line via ``parse_args()``. ``--model``
    selects the weights: a path ending in ``.ckpt`` is loaded with
    ``load_checkpoint``/``load_param_into_net``; any other path is passed to
    ``load_pytorch_state_dict``.

    Args:
        image1_path: ``.npy`` file loaded as the first example input for export.
        image2_path: ``.npy`` file loaded as the second example input.
        file_name: output name passed to ``mindspore.export``.

    Raises:
        ValueError: if ``--model`` was not supplied.
    """
    # Parse first so that --device_target is honoured; previously "Ascend"
    # was hard-coded and the option ignored (its default is still "Ascend").
    args = parse_args()
    context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
    if not args.model:
        raise ValueError('--model (path to the weights to export) is required')

    model = RAFTGMA(args.__dict__)
    if args.model.endswith('.ckpt'):
        param_dict = load_checkpoint(args.model)
        load_param_into_net(model, param_dict)
    else:
        # presumably a PyTorch state dict converted by the project helper — verify
        model = load_pytorch_state_dict(model, args.model)
    # Debug message ("loading checkpoint: <path>"). The original printed
    # `args.restore_ckpt`, which parse_args never defines (AttributeError).
    print("[DEBUG] 加载checkpoint: ", args.model)
    model.set_train(mode=False)

    img1 = np.load(image1_path)
    print(img1.shape)
    img2 = np.load(image2_path)
    print(img2.shape)

    export(model, Tensor(img1), Tensor(img2), file_name=file_name, file_format='MINDIR')
-
if __name__ == '__main__':
    # Seed MindSpore, NumPy and Python's `random` (in that order) with the
    # same fixed value, presumably so any randomness during export is
    # reproducible.
    for _seed_rng in (ms.set_seed, np.random.seed, random.seed):
        _seed_rng(1234)

    export_mindir()
|