|
- import sys
- sys.path.append('core')
-
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore import context
- from mindspore import Tensor, export, load_checkpoint, load_param_into_net, save_checkpoint
-
- import argparse
- import os
- import cv2
- import glob
- import numpy as np
- import torch
- from PIL import Image
- import imageio
- import matplotlib.pyplot as plt
-
- from network import RAFTGMA
- from core.utils import flow_viz
- from core.utils.utils import InputPadder
- import os
-
-
-
def load_image(imfile):
    """Read an image file and return it as a float tensor of shape (1, C, H, W).

    The image is loaded via PIL, kept as uint8, converted to a torch tensor
    in channels-first layout, and given a leading batch dimension.
    """
    frame = np.array(Image.open(imfile), dtype=np.uint8)
    tensor = torch.from_numpy(frame).permute(2, 0, 1).float()
    return tensor.unsqueeze(0)
-
-
def viz(img, flo, flow_dir):
    """Render an optical-flow field as an RGB image and save it to disk.

    Both *img* and *flo* are batched channels-first arrays; only the first
    element of the batch is used. The visualisation is written to
    ``<flow_dir>/flo.png``.
    """
    frame = np.moveaxis(img[0], 0, -1)
    flow = np.moveaxis(flo[0], 0, -1)

    # Convert the 2-channel flow field to an RGB visualisation.
    flow_rgb = flow_viz.flow_to_image(flow)

    out_path = os.path.join(flow_dir, 'flo.png')
    imageio.imwrite(out_path, flow_rgb)
    print(f"Saving optical flow visualisation at {out_path}")
-
-
def normalize(x):
    """Min-max normalize *x* into the range [0, 1].

    Fixes the previous formula, which divided by the range without first
    subtracting the minimum, so the smallest value was not mapped to 0
    (e.g. ``[2, 4]`` came out as ``[1, 2]`` instead of ``[0, 1]``).

    NOTE(review): like the original, this divides by zero for constant
    input (``x.max() == x.min()``); callers must avoid that case.
    """
    return (x - x.min()) / (x.max() - x.min())
-
-
def demo(args):
    """Run GMA optical-flow inference over consecutive image pairs.

    Restores the checkpoint named by ``args.model``, then walks every
    consecutive pair of PNG/JPG frames under ``args.path``, runs the model
    for 12 refinement iterations, and writes a flow visualisation into
    ``<args.path>/<args.model_name>/``.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
    model = RAFTGMA(args.__dict__)

    # Restore weights and switch the network to evaluation mode.
    param_dict = load_checkpoint(args.model)
    load_param_into_net(model, param_dict)
    model.set_train(mode=False)
    print(f"Loaded checkpoint at {args.model}")

    flow_dir = os.path.join(args.path, args.model_name)
    if not os.path.exists(flow_dir):
        os.makedirs(flow_dir)

    # Collect every PNG/JPG frame; sorting gives temporal order by filename.
    frames = sorted(
        glob.glob(os.path.join(args.path, '*.png'))
        + glob.glob(os.path.join(args.path, '*.jpg'))
    )

    for imfile1, imfile2 in zip(frames[:-1], frames[1:]):
        image1 = load_image(imfile1)
        image2 = load_image(imfile2)
        print(f"Reading in images at {imfile1} and {imfile2}")

        # Pad both frames to dimensions the network accepts.
        padder = InputPadder(image1.shape)
        image1, image2 = padder.pad(image1, image2)

        # Bridge from torch tensors to MindSpore float32 tensors.
        image1 = ms.Tensor(image1.numpy()).astype(ms.float32)
        image2 = ms.Tensor(image2.numpy()).astype(ms.float32)

        # Inference: 12 refinement iterations; test mode returns the flow pair.
        _, flow_low, flow_up = model(image1, image2, iters=12, test_mode=True)

        # Convert results back to ndarrays for visualisation.
        flow_low = flow_low.asnumpy()
        flow_up = flow_up.asnumpy()
        image1 = image1.asnumpy()
        print(f"Estimating optical flow...")

        viz(image1, flow_up, flow_dir)
-
-
if __name__ == '__main__':
    # CLI entry point: parse model/data options and run the flow demo.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--model_name', help="define model name", default="GMA")
    parser.add_argument('--path', help="dataset for evaluation")
    parser.add_argument(
        '--num_heads', default=1, type=int,
        help='number of heads in attention and aggregation',
    )
    parser.add_argument(
        '--position_only', default=False, action='store_true',
        help='only use position-wise attention',
    )
    parser.add_argument(
        '--position_and_content', default=False, action='store_true',
        help='use position and content-wise attention',
    )
    parser.add_argument(
        '--mixed_precision', action='store_true',
        help='use mixed precision',
    )

    demo(parser.parse_args())
|