|
- import tensorflow as tf
- import torch
- import numpy as np
-
-
# TF BatchNorm variable suffix -> PyTorch BatchNorm parameter/buffer name,
# listed in the order the renames are applied.
_BN_PARAM_MAP = (
    ("beta", "bias"),
    ("gamma", "weight"),
    ("moving_mean", "running_mean"),
    ("moving_variance", "running_var"),
)

# Squeeze-and-excitation conv renames applied to MBConv block names.
_SE_NAME_MAP = (
    ("se.conv2d.bias", "se.conv_reduce.bias"),
    ("se.conv2d.kernel", "se.conv_reduce.weight"),
    ("se.conv2d_1.bias", "se.conv_expand.bias"),
    ("se.conv2d_1.kernel", "se.conv_expand.weight"),
)


def _rename_bn(name: str, tf_bn: str, torch_bn: str) -> str:
    """Rewrite each ``<tf_bn>.<tf param>`` substring of *name* to ``<torch_bn>.<torch param>``."""
    for tf_param, torch_param in _BN_PARAM_MAP:
        name = name.replace(f"{tf_bn}.{tf_param}", f"{torch_bn}.{torch_param}")
    return name


def _rename_stem(name: str) -> str:
    """Map a dotted stem variable name (conv2d + BN) to its PyTorch name."""
    name = name.replace("conv2d.kernel", "conv.weight")
    return _rename_bn(name, "tpu_batch_normalization", "bn")


def _rename_head(name: str) -> str:
    """Map a dotted head variable name (conv2d, BN, dense) to its PyTorch name."""
    name = name.replace("conv2d.kernel", "project_conv.conv.weight")
    name = name.replace("dense.kernel", "classifier.weight")
    name = name.replace("dense.bias", "classifier.bias")
    return _rename_bn(name, "tpu_batch_normalization", "project_conv.bn")


def _rename_block(name: str, stage0_num: int, fused_conv_num: int) -> str:
    """Map a dotted ``blocks_<i>.*`` variable name to its PyTorch name.

    Blocks with index < stage0_num are expansion=1 fused MBConv, blocks with
    index < fused_conv_num are fused MBConv, and the rest are MBConv.

    Raises:
        ValueError: if the first name component is not ``blocks_<int>``.
    """
    # e.g. "blocks_0.conv2d.kernel" -> "0"
    block_str = name.split(".", maxsplit=1)[0].replace("blocks_", "")
    try:
        block_idx = int(block_str)
    except ValueError as err:
        raise ValueError(
            f"cannot parse a block index from {name!r}; does model_name "
            f"match the checkpoint's variable-name prefix?") from err
    prefix = f"blocks.{block_str}"
    name = name.replace(f"blocks_{block_str}", prefix)

    if block_idx < stage0_num:  # expansion=1 fused_mbconv
        name = name.replace("conv2d.kernel", "project_conv.conv.weight")
        return _rename_bn(name, "tpu_batch_normalization", "project_conv.bn")

    # Fused MBConv and MBConv both map conv2d -> expand_conv and
    # conv2d_1 -> project_conv. The block prefix keeps these replacements
    # from also matching the SE convs ("se.conv2d", "se.conv2d_1").
    name = name.replace(f"{prefix}.conv2d.kernel",
                        f"{prefix}.expand_conv.conv.weight")
    name = _rename_bn(name, "tpu_batch_normalization", "expand_conv.bn")
    name = name.replace(f"{prefix}.conv2d_1.kernel",
                        f"{prefix}.project_conv.conv.weight")

    if block_idx < fused_conv_num:  # fused_mbconv: 2nd BN follows the projection
        return _rename_bn(name, "tpu_batch_normalization_1", "project_conv.bn")

    # mbconv: depthwise conv + BN, projection BN, squeeze-and-excitation convs.
    name = name.replace("depthwise_conv2d.depthwise_kernel", "dwconv.conv.weight")
    name = _rename_bn(name, "tpu_batch_normalization_1", "dwconv.bn")
    name = _rename_bn(name, "tpu_batch_normalization_2", "project_conv.bn")
    for tf_part, torch_part in _SE_NAME_MAP:
        name = name.replace(tf_part, torch_part)
    return name


def _convert_name(tf_name: str, model_name: str,
                  stage0_num: int, fused_conv_num: int) -> str:
    """Translate a TF checkpoint variable name into a PyTorch state-dict key.

    Names outside stem/head/blocks are reported on stdout and returned only
    with the model prefix stripped and "/" replaced by ".".
    """
    name = tf_name.replace(model_name + "/", "").replace("/", ".")
    # Dispatch on the original TF name, as substrings, in this priority order.
    if "stem" in tf_name:
        return _rename_stem(name)
    if "head" in tf_name:
        return _rename_head(name)
    if "blocks" in tf_name:
        return _rename_block(name, stage0_num, fused_conv_num)
    print("not recognized name: " + tf_name)
    return name


def _expect_rank(name: str, value, rank: int) -> None:
    """Raise ValueError unless *value* has exactly *rank* dimensions."""
    if len(value.shape) != rank:
        raise ValueError(
            f"expected a rank-{rank} tensor for {name!r}, got shape {value.shape}")


def _convert_tensor(name: str, value):
    """Return *value* transposed from TF layout to PyTorch layout where needed.

    Raises:
        ValueError: if a conv/dense kernel does not have the expected rank.
    """
    if "conv" in name and "weight" in name and "bn" not in name and "dw" not in name:
        _expect_rank(name, value, 4)
        # conv kernel [h, w, c, n] -> [n, c, h, w]
        return np.transpose(value, (3, 2, 0, 1))
    if "bn" in name:
        # BN parameters/statistics are 1-D and need no layout change.
        return value
    if "dwconv" in name and "weight" in name:
        _expect_rank(name, value, 4)
        # dw_kernel [h, w, n, c] -> [n, c, h, w]
        return np.transpose(value, (2, 3, 0, 1))
    if "classifier" in name and "weight" in name:
        _expect_rank(name, value, 2)
        # dense kernel [in, out] -> linear weight [out, in]
        return np.transpose(value, (1, 0))
    return value


def main(model_name: str = "efficientnetv2-s",
         tf_weights_path: str = "./efficientnetv2-s/model",
         stage0_num: int = 2,
         fused_conv_num: int = 10):
    """Convert a TF EfficientNetV2 checkpoint into a PyTorch state-dict file.

    Every checkpoint variable is converted except ``global_step`` and names
    containing "Exponential" (presumably the EMA shadow copies; verify against
    the checkpoint). Each one is renamed to the PyTorch module layout, and
    conv/dense kernels are transposed to PyTorch's layout.

    Args:
        model_name: prefix of the TF variable names (``"<model_name>/..."``).
            It is stripped before renaming and also names the output file.
        tf_weights_path: checkpoint path read with ``tf.train``.
        stage0_num: number of leading blocks that are expansion=1 fused MBConv.
        fused_conv_num: number of leading blocks that are fused MBConv. Later
            blocks are MBConv (expand, depthwise, SE, project).

    Raises:
        ValueError: if a block name cannot be parsed as ``blocks_<int>`` (for
            example because model_name does not match the checkpoint), or if a
            kernel has an unexpected rank.

    The result is written with ``torch.save`` to ``"pre_<model_name>.pth"``
    in the current working directory.
    """
    except_var = ["global_step"]

    var_list = [i for i in tf.train.list_variables(tf_weights_path)
                if "Exponential" not in i[0]]
    reader = tf.train.load_checkpoint(tf_weights_path)

    new_weights = {}
    for tf_name, _shape in var_list:
        if tf_name in except_var:
            continue
        new_name = _convert_name(tf_name, model_name, stage0_num, fused_conv_num)
        new_var = _convert_tensor(new_name, reader.get_tensor(tf_name))
        new_weights[new_name] = torch.as_tensor(new_var)

    torch.save(new_weights, "pre_" + model_name + ".pth")
-
-
if __name__ == '__main__':
    # Conversion settings per released variant: checkpoint directory, number
    # of expansion=1 fused blocks in stage 0, and number of leading fused
    # MBConv blocks. Only the "-s" variant is converted; change `variant` to
    # convert another one.
    variant_configs = {
        "efficientnetv2-s": {"tf_weights_path": "./efficientnetv2-s/model",
                             "stage0_num": 2,
                             "fused_conv_num": 10},
        "efficientnetv2-m": {"tf_weights_path": "./efficientnetv2-m/model",
                             "stage0_num": 3,
                             "fused_conv_num": 13},
        "efficientnetv2-l": {"tf_weights_path": "./efficientnetv2-l/model",
                             "stage0_num": 4,
                             "fused_conv_num": 18},
    }
    variant = "efficientnetv2-s"
    main(model_name=variant, **variant_configs[variant])
|