@@ -81,9 +81,11 @@ class SolutionTestBase:
def generate_case_params(self):
args_dict = dict()
input_dict = dict()
yaml_content = yaml_read(self.yaml_path)
args = yaml_content.get("args")
variation = yaml_content.get("variation")
input_shape = yaml_content.get("input_shape")
self.params_dict["op_name"] = yaml_content.get("op_name")
self.params_dict["torch_op_name"] = yaml_content.get("torch_op_name")
# 解析yaml配置文件参数
@@ -92,7 +94,7 @@ class SolutionTestBase:
# BC覆盖,生成所有参数组合
self.params_dict = self.bc_generate_args(self.params_dict)
# 获取输入shape
self.params_dict = self.get_input_shape(yaml_content.get(" input_shape") , self.params_dict)
self.params_dict = self.get_input_shape(input_shape, self.params_dict)
# 执行约束条件
if variation is not None:
self.params_dict = self.execute_variation(variation, self.params_dict)
@@ -105,8 +107,12 @@ class SolutionTestBase:
for key in args.keys():
if self.params_dict.get(key) is not None:
args_dict[key] = self.params_dict.get(key)
if input_shape is not None:
for key in input_shape.keys():
if self.params_dict.get(key) is not None:
input_dict[key] = self.params_dict.get(key)
self.ms_log.info("params_dict: %s, args_dict: %s", self.params_dict, args_dict)
return self.params_dict, args_dict
return self.params_dict, args_dict, input_shape
def generate_input_shape(self):
"""
@@ -114,17 +120,36 @@ class SolutionTestBase:
"""
ms_input_shape_list = list()
torch_input_shape_list = list()
params_dict, _ = self.generate_case_params()
for key, value in params_dict.items():
if "input" not in key:
continue
random_shape = np.random.randn(*params_dict.get(key).get("input_shape"))
torch_input_shape = torch.tensor(random_shape, dtype=params_dict.get("torch_dtype"))
ms_input_shape = ms_pytorch.tensor(random_shape, dtype=params_dict.get("ms_dtype"))
_, args_dict, input_dict = self.generate_case_params()
for key, value in input_dict.items():
random_shape = np.random.randn(*input_dict.get(key).get("input_shape"))
torch_input_shape = torch.tensor(random_shape, dtype=input_dict.get("torch_dtype"))
ms_input_shape = ms_pytorch.tensor(random_shape, dtype=input_dict.get("ms_dtype"))
ms_input_shape_list.append(ms_input_shape)
torch_input_shape_list.append(torch_input_shape)
return torch_input_shape_list, ms_input_shape_list
def generate_args_with_tensor(self, args):
"""
生成参数中的tensor类型数据
:param args: 设为tesnor数据的参数名
:param frame_name: 使用的框架名,取值分别为:"torch"或"ms_pytorch"
"""
_, args_dict, _ = self.generate_case_params()
numpy_shape = args_dict.get(args).get("numpy_shape")
numpy_type = args_dict.get(args).get("numpy_type")
random_shape = np.random.randn(*numpy_shape)
torch_tensor = torch.tensor(random_shape, dtype=eval(".".join(["torch", numpy_type])))
args_dict["weight"] = torch_tensor
torch_args = args_dict.copy()
ms_tensor = ms_pytorch.tensor(random_shape, dtype=eval(".".join(["ms_pytorch", numpy_type])))
args_dict["weight"] = ms_tensor
ms_args = args_dict
self.ms_log.info("func generate_args_with_tensor get torch_args: %s\n ms_args: %s\n", torch_args, ms_args)
return torch_args, ms_args
def parse_yaml_args(self, args):
"""
解析yaml文件参数
@@ -136,17 +161,19 @@ class SolutionTestBase:
support_type = value.get("support_type", None)
if not isinstance(value, dict):
raise TypeError("Yaml config key: %s isn't dict type." % key)
if support_type is None:
raise TypeError("Yaml config missing key: %s attributes: support_type" % key)
elif support_value_range is None:
raise TypeError("Yaml config missing key: %s attributes: support_value_range" % key)
if isinstance(support_type, str):
if isinstance(support_value_range, dict):
support_value_range = tuple(support_value_range.get(support_type, [1]))
if isinstance(support_value_length, dict):
support_value_length = tuple(support_value_length.get(support_type, [1]))
params_dict = self.param_value_generate(key, support_type,
support_value_range, support_value_length, params_dict)
if support_type == "numpy":
params_dict = self.numpy_value_generate(key, value, params_dict)
else:
params_dict = self.param_value_generate(key, support_type,
support_value_range,
support_value_length,
params_dict)
elif isinstance(support_type, list):
for v_type in support_type:
if isinstance(support_value_range, dict):
@@ -157,9 +184,41 @@ class SolutionTestBase:
value_length = copy.deepcopy(tuple(support_value_length.get(v_type, [1])))
else:
value_length = tuple(support_value_length)
params_dict = self.param_value_generate(key, v_type,
value_range,
value_length, params_dict)
if v_type == "numpy":
params_dict = self.numpy_value_generate(key, value, params_dict)
else:
params_dict = self.param_value_generate(key, v_type,
value_range,
value_length,
params_dict)
return params_dict
def numpy_value_generate(self, key, value, params_dict):
"""
numpy类型数值生成
:param key: 键名
:param value: 参数
:param params_dict: 最终参数结果
"""
numpy_shape_list = value.get("numpy_shape", list())
numpy_data_type_list = value.get("numpy_type", ["float32"])
if not numpy_shape_list:
return params_dict
if isinstance(numpy_shape_list[0], list):
numpy_shape = random.choice(numpy_shape_list)
else:
numpy_shape = numpy_shape_list
if isinstance(numpy_data_type_list, list):
numpy_data_type = random.choice(numpy_data_type_list)
else:
numpy_data_type = numpy_data_type_list
params_dict[key] = {
"numpy_shape": numpy_shape,
"numpy_type": numpy_data_type
}
return params_dict
def param_value_generate(self, key, data_type, support_value_range, support_value_length, params_dict):
@@ -174,9 +233,14 @@ class SolutionTestBase:
if data_type != 'default':
params_dict[key] = list()
if data_type in ["int", "float", "list", "tuple"]:
normal_value_list = [support_value_range[0], support_value_range[0] + 1,
support_value_range[1] - 1, support_value_range[1],
int(np.median(support_value_range))]
# float类型时,如果右侧数值小于1则不取边界值
if data_type == "float" and support_value_range[1] <= 1:
normal_value_list = [support_value_range[0], support_value_range[1],
float(np.median(support_value_range))]
else:
normal_value_list = [support_value_range[0], support_value_range[0] + 1,
support_value_range[1] - 1, support_value_range[1],
int(np.median(support_value_range))]
if support_value_range[0] == support_value_range[1]:
if data_type == "int":
params_dict[key].append(support_value_range[0])
@@ -217,9 +281,11 @@ class SolutionTestBase:
for i in range(length + 1):
if i != length:
continue
value = random.sample([v for v in range(normal_value_list[0],
normal_value_list[-1] + 1)], i)
value.sort()
normal_value_range_list = [v for v in range(normal_value_list[0],
normal_value_list[-1] + 1)]
if len(normal_value_range_list) < i:
continue
value = random.sample(normal_value_range_list, i)
if value not in params_dict[key]:
params_dict[key].append(value)
else:
@@ -230,9 +296,11 @@ class SolutionTestBase:
for i in range(length + 1):
if i != length:
continue
value = random.sample([v for v in range(normal_value_list[0],
normal_value_list[-1] + 1)], i)
value.sort()
normal_value_range_list = [v for v in range(normal_value_list[0],
normal_value_list[-1] + 1)]
if len(normal_value_range_list) < i:
continue
value = random.sample(normal_value_range_list, i)
if tuple(value) not in params_dict[key]:
params_dict[key].append(tuple(value))
elif data_type == "str":
@@ -241,8 +309,7 @@ class SolutionTestBase:
params_dict[key] = support_value_range
return params_dict
@staticmethod
def bc_generate_args(params_dict):
def bc_generate_args(self, params_dict):
"""
bc覆盖算法,获取所以参数组合
:param params_dict: 最终参数结果
@@ -251,10 +318,15 @@ class SolutionTestBase:
base_params_dict = dict()
# 确定首个参数集合
for key, value in params_dict.items():
base_params_dict[key] = value[0]
if isinstance(value, dict):
base_params_dict[key] = value
else:
base_params_dict[key] = value[0]
args_list.append(base_params_dict)
for key, value in params_dict.items():
tmp_dict = copy.deepcopy(base_params_dict)
if isinstance(value, dict):
continue
for sub_value in value:
if sub_value == base_params_dict.get(key):
continue