|
# PyTorch einsum
- # torch_time 0.13966515064239501(with random)
- # torch_time 0.026357293128967285
- # torch_time 0.02468111515045166
- # torch_time 0.02991375923156738
- import numpy as np
- import time
- import os
- import mindspore
- from mindspore import context
- from mindspore.ops import Transpose, BatchMatMul, MatMul, Reshape
-
# MindSpore primitive ops instantiated once at module level and reused by the
# functions below (transpose/batchMatMul/matmul/reshape are referenced as globals).
transpose = Transpose()
batchMatMul = BatchMatMul()
matmul = MatMul()
reshape = Reshape()
-
-
def generate_random_data(n):
    """Build n pairs of random query/key arrays of shape (512, 22, 8, 64).

    The RNG is reseeded on every call, so repeated calls produce identical
    data. Draw order is q[0], k[0], q[1], k[1], ... (interleaved).

    Args:
        n: number of (query, key) pairs to generate.

    Returns:
        Tuple (rw_head_q, w_head_k) of two length-n lists of float64 arrays.
    """
    np.random.seed(1234)
    shape = (512, 22, 8, 64)
    rw_head_q, w_head_k = [], []
    for _ in range(n):
        rw_head_q.append(np.random.rand(*shape))
        w_head_k.append(np.random.rand(*shape))
    return rw_head_q, w_head_k
-
-
- def test_torch(rw_head_q, w_head_k):
- rw_head_q_torch = torch.Tensor(rw_head_q)
- w_head_k_torch = torch.Tensor(w_head_k)
- AC_torch = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q_torch, w_head_k_torch))
- return AC_torch
-
-
def test_ascend_float32(rw_head_q, w_head_k):
    """Compute AC = einsum('ibnd,jbnd->ijbn', q, k) on Ascend in float32.

    Einsum is unavailable here, so the contraction is decomposed into
    Transpose + BatchMatMul.

    Args:
        rw_head_q: array-like of shape (i, b, n, d), e.g. (512, 22, 8, 64).
        w_head_k: array-like of shape (j, b, n, d).

    Returns:
        mindspore.Tensor of shape (i, j, b, n), dtype float32.
    """
    rw_head_q_msp = mindspore.Tensor(rw_head_q, dtype=mindspore.float32)
    # BUG FIX: Tensor.shape is a property, not a method — `.shape()` raised
    # TypeError ('tuple' object is not callable).
    print('rw_head_q_msp : ', rw_head_q_msp.shape)
    w_head_k_msp = mindspore.Tensor(w_head_k, dtype=mindspore.float32)
    print('w_head_k_msp : ', w_head_k_msp.shape)
    # (i, b, n, d) -> (b, n, i, d): batch axes (b, n) to the front.
    rw_head_q_msp_t = transpose(rw_head_q_msp, (1, 2, 0, 3))
    print('rw_head_q_msp_t : ', rw_head_q_msp_t.shape)
    # (j, b, n, d) -> (b, n, d, j): contraction axis d placed before j.
    w_head_k_msp_t = transpose(w_head_k_msp, (1, 2, 3, 0))
    print('w_head_k_msp_t : ', w_head_k_msp_t.shape)
    # BUG FIX: the previous version flattened both operands to 2-D and used a
    # plain MatMul; a (b*n*i, d) @ (d, b*n*j) product mixes the (b, n) batch
    # blocks and is not the batched contraction (its output could not even be
    # reshaped to (b, n, i, j)). BatchMatMul keeps each (b, n) slice
    # independent: (b, n, i, d) @ (b, n, d, j) -> (b, n, i, j).
    r = batchMatMul(rw_head_q_msp_t, w_head_k_msp_t)
    print('r : ', r.shape)
    # (b, n, i, j) -> (i, j, b, n)
    AC_msp = transpose(r, (2, 3, 0, 1))
    print('AC_msp : ', AC_msp.shape)
    return AC_msp
-
-
def test_ascend_float16(rw_head_q, w_head_k):
    """Compute AC = einsum('ibnd,jbnd->ijbn', q, k) on Ascend in float16.

    Same Transpose + BatchMatMul decomposition as test_ascend_float32, with
    the operands held in half precision.

    Args:
        rw_head_q: array-like of shape (i, b, n, d).
        w_head_k: array-like of shape (j, b, n, d).

    Returns:
        mindspore.Tensor of shape (i, j, b, n), dtype float16.
    """
    # BUG FIX: this function benchmarked "float16" but constructed float32
    # tensors; use float16 so the measurement matches the name (the recorded
    # float16-vs-float32 diff below also indicates half precision was run).
    rw_head_q_msp = mindspore.Tensor(rw_head_q, dtype=mindspore.float16)
    w_head_k_msp = mindspore.Tensor(w_head_k, dtype=mindspore.float16)
    # (b, n, i, d) @ (b, n, d, j) -> (b, n, i, j), then reorder to (i, j, b, n).
    AC_msp = transpose(batchMatMul(transpose(rw_head_q_msp, (1, 2, 0, 3)),
                                   transpose(w_head_k_msp, (1, 2, 3, 0))), (2, 3, 0, 1))
    return AC_msp
-
-
def torch2msp(torch_tensor):
    """Convert a torch.Tensor into a float32 mindspore.Tensor via numpy."""
    return mindspore.Tensor(torch_tensor.numpy(), dtype=mindspore.float32)
-
-
def cmp(msp_result, torch_result):
    """Return the summed elementwise difference of two tensors.

    Note: positive and negative errors cancel — this is a signed total,
    not an absolute-error metric.
    """
    return (msp_result - torch_result).sum()
-
-
if __name__ == '__main__':
    # Generate seeded random test data.
    n = 1  # number of (q, k) pairs to benchmark
    rw_head_q, w_head_k = generate_random_data(n)

    # --- PyTorch reference (disabled; enable to compare against the Ascend
    # results via the commented diff checks below) ---
    # start_time = time.time()
    # AC_torch_list, AC_msp_list = [], []
    # for i in range(n):
    #     AC_torch_list.append(test_torch(rw_head_q[i], w_head_k[i]))
    # torch_time = time.time() - start_time
    # print('torch_time : ', torch_time / n, 's')
    # # Convert to mindspore Tensors so results can be compared below.
    # for i in range(n):
    #     AC_msp_list.append(torch2msp(AC_torch_list[i]))
    # BUG FIX: the torch2msp conversion loop above used to run unconditionally
    # while the torch section was commented out, raising NameError on
    # AC_torch_list/AC_msp_list at startup; it belongs inside that section.

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=0)
    AC_float16_list, AC_float32_list = [], []

    # Benchmark the float32 Ascend implementation.
    start_time = time.time()
    for i in range(n):
        AC_float32_list.append(test_ascend_float32(rw_head_q[i], w_head_k[i]))
    msp_time = time.time() - start_time
    print('msp_time(float32) : ', msp_time / n, 's')

    # Check float32 results against the torch reference (requires the torch
    # section above to be enabled).
    # diff = 0.0
    # for i in range(n):
    #     diff += cmp(AC_float32_list[i], AC_msp_list[i])
    # print('msp_diff(float32) : ', diff)

    # Benchmark the float16 Ascend implementation.
    start_time = time.time()
    for i in range(n):
        AC_float16_list.append(test_ascend_float16(rw_head_q[i], w_head_k[i]))
    msp_time = time.time() - start_time
    print('msp_time(float16) : ', msp_time / n, 's')

    # Check float16 results against the torch reference (requires the torch
    # section above to be enabled).
    # diff = 0.0
    # for i in range(n):
    #     diff += cmp(AC_float16_list[i], AC_msp_list[i])
    # print('msp_diff(float16) : ', diff)

    # Compare float16 directly against float32 (label fixed: the old print
    # reused 'msp_diff(float16)' although this measures float16 vs float32).
    diff = 0.0
    for i in range(n):
        diff += cmp(AC_float16_list[i], AC_float32_list[i])
    print('msp_diff(float16 vs float32) : ', diff)
-
- # 测试结果
- # torch_time : 0.4327152013778687 s
- # msp_time(float32) : 5.866198801994324 s
- # msp_diff(float32) : -0.016492367
- # msp_time(float16) : 5.627848029136658 s
- # msp_diff(float16) : -692.88654
-
- # PyTorch einsum
- # torch_time 0.4327152013778687 s
# MindSpore: BatchMatMul and Transpose used as a replacement for Einsum
- # msp_time(float32) 5.866198801994324 s(float32) GRAPH_MODE
- # msp_time(float16) 5.627848029136658 s(float16) GRAPH_MODE
-
- """
- #
- from mindspore.ops import Einsum
- x = mindspore.Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
- equation = "i->"
- einsum = Einsum(equation)
- output = einsum(*[x])
- print(output)
- """
- # a1 = 694.27
- # a2 = 447.14
- # b1 = a1 * (0.4590 + 0.0591)
- # diff = (0.0247 / 0.6695)
- # c = 694.27 - 694.27 * (0.4590 + 0.0591) * (1 - 0.0247 / 0.6695) # 347.8392459184466
|