|
- import numpy as np
- from mindspore import nn
- from mindspore.common.tensor import Tensor
- from mindspore.common import dtype as mstype
- from mindspore.common import Parameter
- from mindspore import ParameterTuple
- from mindspore.ops.composite import GradOperation
- from mindspore.ops import operations as P
- from mindspore import context
- from src.submodels.custom_ops.test_custom import Correlation ,Resample2D
- # from src.submodels.custom_ops.testBatch import Resample2D
-
-
- context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=5)
-
-
class Net(nn.Cell):
    """Scale the first input by a learnable scalar ``z``, then matmul with the second.

    construct(x, y) computes ``matmul(x * z, y)`` where ``z`` is a trainable
    one-element float32 parameter initialised to 1.0.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.matmul = P.MatMul()
        self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

    def construct(self, x, y):
        # Elementwise scale by the learnable parameter, then matrix-multiply.
        scaled = self.z * x
        return self.matmul(scaled, y)
-
-
class GradNetWrtX(nn.Cell):
    """Wrap a network and return the gradient of its output w.r.t. the first input.

    Default ``GradOperation()`` differentiates only with respect to the first
    positional argument of the wrapped cell.
    """

    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net
        self.grad_op = GradOperation()

    def construct(self, x, y):
        grad_fn = self.grad_op(self.net)
        return grad_fn(x, y)
-
class GradNetWrtXY(nn.Cell):
    """Wrap a network and return gradients of its output w.r.t. ALL inputs.

    ``GradOperation(get_all=True)`` yields one gradient per positional input,
    so ``construct`` returns a tuple of gradients for ``(x, y)``.
    """

    def __init__(self, net):
        super(GradNetWrtXY, self).__init__()
        self.net = net
        self.grad_op = GradOperation(get_all=True)

    def construct(self, x, y):
        grad_fn = self.grad_op(self.net)
        return grad_fn(x, y)
-
-
-
if __name__ == '__main__':
    # Smoke test: differentiate the custom Resample2D op w.r.t. its first
    # input, using deterministic ramp tensors (values 0..N-1 in NCHW layout)
    # so runs are reproducible.
    print('==' * 30)
    print('==' * 30)

    source = Tensor(np.arange(0, 2 * 3 * 48 * 64).reshape(2, 3, 48, 64).astype(np.float32))
    flow = Tensor(np.arange(0, 2 * 3 * 48 * 64).reshape(2, 3, 48, 64).astype(np.float32))

    # GradNetWrtX differentiates only w.r.t. the first input (source).
    gradient = GradNetWrtX(Resample2D())(source, flow)
    print('gradient w.r.t. first input:', gradient)
|