|
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore import ops
- from mindspore import context
- import mindspore.numpy as msnp
- from mindspore.common.initializer import initializer, HeNormal, HeUniform, One, Zero
- from mindspore import load_checkpoint,load_param_into_net
- from mindspore.ops import constexpr
- from mindspore import ms_function
- from mindspore import dtype as mstype
- from mindspore import dtype as mstype
- from core.utils.ms_utils import *
-
- from core.update import GMAUpdateBlock
- from core.extractor import BasicEncoder
- from core.corr import CorrBlock
- from core.gma import Attention, Aggregate
- import time
-
-
class RAFTGMA(nn.Cell):
    """RAFT optical-flow network with Global Motion Aggregation (GMA),
    ported to MindSpore.

    Given two input frames, ``construct`` iteratively refines a flow field
    at 1/8 of the input resolution and upsamples each estimate back to
    full resolution via a learned convex combination.

    Args:
        args (dict): configuration. Every key/value pair is also copied
            onto the instance as an attribute (see the setattr loop), so
            it must at least provide the keys read later, e.g.
            'num_heads' for the attention module. 'corr_levels',
            'corr_radius' and (if absent) 'dropout' are written here.
    """
    def __init__(self, args):
        super().__init__()
        self.args = args

        self.hidden_dim = hdim = 128
        self.context_dim = cdim = 128
        # NOTE(review): these writes mutate the caller's dict in place.
        args['corr_levels'] = 4
        args['corr_radius'] = 4

        if 'dropout' not in self.args.keys():
            self.args['dropout'] = 0

        # replace argparser: expose every config entry as an instance
        # attribute (self.num_heads, self.corr_levels, ... are read below)
        for key, value in self.args.items():
            setattr(self, key, value)

        # feature network, context network, and update block
        self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.get('dropout',0))
        self.cnet = BasicEncoder(output_dim=hdim + cdim, norm_fn='batch', dropout=args.get('dropout',0))
        self.update_block = GMAUpdateBlock(self.args, hidden_dim=hdim)
        self.att = Attention(args=self.args, dim=cdim, heads=self.num_heads, max_pos_size=160, dim_head=cdim)
        self.corr_block = CorrBlock(args=self.args)

        self.tanh = ms.ops.Tanh()
        self.relu = ms.ops.ReLU()

        # for upsample flow; Upflow8 presumably comes from the star import
        # of core.utils.ms_utils — TODO confirm
        self.upflow8 = Upflow8()
        self.softmax = nn.Softmax(axis=2)
        self.pad_unfold = nn.Pad(paddings=((0,0),(0,0),(1,1),(1,1)))
        self.unfold = nn.Unfold(ksizes=[1,3,3,1],strides=[1,1,1,1],rates=[1,1,1,1],padding='valid')

        # for corr pyramid init
        self.bmm = ms.ops.BatchMatMul()  # fp16 optimization (see self.corr)
        self.sqrt = ms.ops.Sqrt()
        self.avg_pool2d = nn.AvgPool2d(kernel_size=2, stride=2)

        self.cast = ms.ops.Cast()



    def freeze_bn(self):
        """Switch every BatchNorm2d cell to eval mode so its moving
        statistics are frozen."""
        for name,cell in self.cells_and_names():
            if isinstance(cell, nn.BatchNorm2d):
                cell.eval()

    def corr_pyramid_init(self, fmap1, fmap2):
        """Build the correlation pyramid looked up by self.corr_block.

        Level 0 is the all-pairs correlation volume reshaped to
        (batch*h1*w1, dim, h2, w2); each of the remaining
        self.corr_levels - 1 levels is a 2x average-pooled copy of the
        previous one.
        """
        corr_pyramid = []

        # all pairs correlation
        corr = self.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        corr_pyramid.append(corr)
        for i in range(self.corr_levels - 1):
            corr = self.avg_pool2d(corr)
            corr_pyramid.append(corr)
        return corr_pyramid

    @staticmethod
    @constexpr
    def corr_dim(dim):
        """Return `dim` as a float32 scalar Tensor; @constexpr lets graph
        mode fold this into a compile-time constant."""
        return ms.Tensor(dim).astype(ms.float32)

    def corr(self, fmap1, fmap2):
        """All-pairs correlation of two (batch, dim, ht, wd) feature maps.

        The batched matmul runs in fp16 and the result is cast back to
        fp32, then scaled by 1/sqrt(dim). Output shape:
        (batch, ht, wd, 1, ht, wd).
        """
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)
        corr_x = msnp.swapaxes(fmap1,1,2)
        corr_y = fmap2
        corr_fp16 = self.bmm(self.cast(corr_x, mstype.float16), self.cast(corr_y, mstype.float16)) # fp16
        corr_fp32 = self.cast(corr_fp16, mstype.float32)
        corr = corr_fp32.reshape(batch, ht, wd, 1, ht, wd)
        return corr / self.sqrt(self.corr_dim(dim))

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        # coords_grid presumably comes from the star import of
        # core.utils.ms_utils — TODO confirm; grids are at 1/8 resolution
        coords0 = coords_grid(N, H // 8, W // 8)
        coords1 = coords_grid(N, H // 8, W // 8)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        # 9 weights (3x3 neighborhood) per 8x8 upsampled sub-pixel position
        mask = mask.reshape(N, 1, 9, 8, 8, H, W)
        mask = self.softmax(mask)
        # scale flow by 8 to match the finer resolution, then gather each
        # coarse pixel's 3x3 neighborhood (explicit pad + unfold)
        up_flow = self.pad_unfold(8*flow)
        # fp16: unfold runs in half precision, cast back to fp32 after
        up_flow_fp16 = self.unfold(self.cast(up_flow, mstype.float16))
        up_flow = self.cast(up_flow_fp16, mstype.float32)
        up_flow = msnp.swapaxes(up_flow.reshape(N,9,2,H,W),1,2)
        up_flow = up_flow.reshape(N, 2, 9, 1, 1, H, W)

        # convex combination: weighted sum over the 9 neighbors
        up_flow = (mask * up_flow).sum(axis=2)
        up_flow = msnp.moveaxis(up_flow,4,2)# (0,1,2,3,4,5)->(0,1,4,2,3,5)->(0,1,4,2,5,3)
        up_flow = msnp.moveaxis(up_flow,-2,-1)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def construct(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """
        # NOTE(review): `upsample` and `test_mode` are accepted but never
        # read in this body — presumably kept for interface parity with
        # the PyTorch original; verify against callers.

        # normalize images from [0, 255] to [-1, 1]
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        # contiguous() is not called here (translated porting note)

        hdim = self.hidden_dim
        cdim = self.context_dim

        # no autocast here (translated porting note)
        # t0 = time.time()
        fmap1, fmap2 = self.fnet([image1, image2])
        fmap1 = fmap1.astype(ms.float32)
        fmap2 = fmap2.astype(ms.float32)
        # t1 = time.time()
        # print('[DEBUG] fnet time: ', t1-t0)

        corr_pyramid = self.corr_pyramid_init(fmap1,fmap2)

        # run the context network
        cnet = self.cnet(image1)
        # split channels into hidden state (hdim) and context input (rest)
        net, inp = msnp.array_split(cnet, [hdim], axis=1) #record split
        net = self.tanh(net)
        inp = self.relu(inp)
        attention = self.att(inp)
        # np.save("pred_ms/cnet.npy", cnet.asnumpy())
        # np.save("pred_ms/attention.npy", attention.asnumpy())
        # print("[DEBUG] cnet attention saved...")

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        flow_up = None
        # print("[DEBUG] iters", iters)
        for itr in range(iters):
            # detach: gradients do not flow back through earlier iterations
            coords1 = ops.stop_gradient(coords1)
            corr = self.corr_block(coords1, corr_pyramid) # index correlation volume
            flow = coords1 - coords0
            net, up_mask, delta_flow = self.update_block(net, inp, corr, flow, attention)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = self.upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)
            flow_predictions.append(flow_up)
        # returns (per-iteration upsampled flows, final coarse flow,
        # final upsampled flow)
        return flow_predictions, coords1 - coords0, flow_up
-
-
- # for debug only
- # if __name__=='__main__':
- # import argparse
- # import numpy as np
- # parser = argparse.ArgumentParser()
- # parser.add_argument('--name', default='bla', help="name your experiment")
- # parser.add_argument('--stage', help="determines which dataset to use for training")
- # parser.add_argument('--validation', type=str, nargs='+')
- # parser.add_argument('--restore_ckpt', help="restore checkpoint")
- # parser.add_argument('--output', type=str, default='checkpoints',
- # help='output directory to save checkpoints and plots')
-
- # parser.add_argument('--lr', type=float, default=0.00002)
- # parser.add_argument('--num_steps', type=int, default=100000)
- # parser.add_argument('--batch_size', type=int, default=6)
- # parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
- # parser.add_argument('--gpus', type=int, nargs='+', default=[0, 1])
-
- # parser.add_argument('--wdecay', type=float, default=.00005)
- # parser.add_argument('--epsilon', type=float, default=1e-8)
- # parser.add_argument('--clip', type=float, default=1.0)
- # parser.add_argument('--dropout', type=float, default=0.0)
- # parser.add_argument('--upsample-learn', action='store_true', default=False,
- # help='If True, use learned upsampling, otherwise, use bilinear upsampling.')
- # parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
- # parser.add_argument('--iters', type=int, default=12)
- # parser.add_argument('--val_freq', type=int, default=10000,
- # help='validation frequency')
- # parser.add_argument('--print_freq', type=int, default=100,
- # help='printing frequency')
-
- # parser.add_argument('--mixed_precision', default=False, action='store_true',
- # help='use mixed precision')
- # parser.add_argument('--model_name', default='', help='specify model name')
-
- # parser.add_argument('--position_only', default=False, action='store_true',
- # help='only use position-wise attention')
- # parser.add_argument('--position_and_content', default=False, action='store_true',
- # help='use position and content-wise attention')
- # parser.add_argument('--num_heads', default=1, type=int,
- # help='number of heads in attention and aggregation')
-
- # args = parser.parse_args()
-
- # parser = argparse.ArgumentParser()
- # parser.add_argument('--model', default='checkpoints/gma-sintel.pth', help="restore checkpoint")
- # parser.add_argument('--dataset', default='sintel', help="dataset for evaluation")
- # parser.add_argument('--iters', type=int, default=12)
- # parser.add_argument('--num_heads', default=1, type=int,
- # help='number of heads in attention and aggregation')
- # parser.add_argument('--position_only', default=False, action='store_true',
- # help='only use position-wise attention')
- # parser.add_argument('--position_and_content', default=False, action='store_true',
- # help='use position and content-wise attention')
- # parser.add_argument('--mixed_precision', default=True, help='use mixed precision')
- # parser.add_argument('--model_name')
-
- # Ablations
- # parser.add_argument('--replace', default=False, action='store_true',
- # help='Replace local motion feature with aggregated motion features')
- # parser.add_argument('--no_alpha', default=False, action='store_true',
- # help='Remove learned alpha, set it to 1')
- # parser.add_argument('--no_residual', default=False, action='store_true',
- # help='Remove residual connection. Do not add local features with the aggregated features.')
-
- # args = parser.parse_args()
- # context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
- # context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
- # model = RAFTGMA(args.__dict__)
- # image1 = ms.Tensor(np.ones((3, 3, 368, 768)).astype(np.float32))
- # image2 = ms.Tensor(np.ones((3, 3, 368, 768)).astype(np.float32))
-
-
- # param_dict = torch.load('E:\BaiduNetdiskDownload\GMA_project\GMAms\checkpoints\gma-sintel.pth',map_location='cpu')
- # from collections import OrderedDict
- # def mod(param_dict):
- # new_dict = OrderedDict()
- # for key, value in param_dict.items():
- # if key.startswith('module.'):
- # if 'rel_ind' in key:
- # new_dict[key[7:]] = ms.Parameter(ms.Tensor(value.numpy()).astype(ms.int32))
- # continue
- # if 'norm' in key and 'weight' in key:
- # key=key[:-6]+'gamma'
- # if 'norm' in key and 'bias' in key:
- # key=key[:-4]+'beta'
- # if 'rel_height' in key and 'weight' in key:
- # key=key[:-6]+'embedding_table'
- # if 'rel_width' in key and 'weight' in key:
- # key=key[:-6]+'embedding_table'
- # if 'aggregator' in key and 'gamma' in key:
- # key=key[:-5]+'gamma_aggregate'
- # if 'running_mean' in key:
- # key=key[:-12]+'moving_mean'
- # if 'running_var' in key:
- # key=key[:-11]+'moving_variance'
- # new_dict[key[7:]] = ms.Parameter(ms.Tensor(value.numpy()).astype(ms.float32))
- # return new_dict
- # param_dict = mod(param_dict)
- # load_param_into_net(model, param_dict)
-
-
-
- # output = model(image1,image2)
-
- # print('input:')
- # print(image1.shape, image2.shape)
- # print('output:')
- # for el in output:
- # print(el.shape)
-
- # pred_output = np.concatenate([el.asnumpy() for el in output],1)
- # gt_output = np.load('./output1.npy')
- # print(np.mean(np.abs(pred_output - gt_output)))
|