|
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """train MaskRcnn and get checkpoint files."""
-
- import time
- import os
- import argparse
- from src.config import config
- from src.mask_rcnn_r50 import MaskTextSpotter_Resnet50
- from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
- from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
- from src.lr_schedule import dynamic_lr
- import mindspore.common.dtype as mstype
- from mindspore import context, Tensor, nn
- from mindspore.communication.management import init, get_rank, get_group_size
- from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
- from mindspore.train import Model
- from mindspore.context import ParallelMode
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore.common import set_seed
- import moxing as mox
-
-
def set_device(args):
    """Configure the execution device and, for multi-device runs, the
    data-parallel context.

    Args:
        args: parsed CLI namespace; reads ``device_target`` and
            ``device_id``.

    Returns:
        int: rank id of this process (0 for single-device runs).

    Bug fix: the previous revision called ``get_group_size()`` before
    ``init()``; MindSpore only allows group-size queries after the
    communication backend is initialized, so the world size is now read
    from the DEVICE_NUM environment variable (set by the launch script),
    as the original commented-out line intended.
    """
    rank = 0
    device_target = args.device_target
    # get_group_size() is invalid before init(); use the env var instead.
    device_num = int(os.environ.get("DEVICE_NUM", "8"))

    print("device_num:", device_num)

    if device_num > 1:
        if device_target == "Ascend":
            context.set_context(device_id=int(os.environ["DEVICE_ID"]))
        init()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            device_num=device_num,
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
        )
        rank = get_rank()
        # Was mislabeled "Rank_size:" — this value is the rank id.
        print("rank id:", rank)
    else:
        context.set_context(device_id=args.device_id)

    return rank
-
-
def _str2bool(value):
    """Parse a command-line boolean.

    argparse does not convert the string "False" to a bool (any non-empty
    string is truthy), so without this converter ``--run_distribute False``
    would still enable distributed training.
    """
    if isinstance(value, bool):
        return value
    return str(value).lower() in ("true", "1", "yes")


# Description fixed: this script trains MaskRcnn, not Lenet.
parser = argparse.ArgumentParser(description="MindSpore MaskRcnn Example")


parser.add_argument("--data_url",
                    help="path to training/inference dataset folder",
                    default="./data")

parser.add_argument("--train_url",
                    help="model folder to save/load",
                    default="./model")

# type=_str2bool so "--run_distribute False" actually disables distribution.
parser.add_argument("--run_distribute", type=_str2bool,
                    help="1P or 8P training", default=True)

parser.add_argument("--multi_data_url",
                    help="model folder to save/load",
                    default="")
parser.add_argument(
    "--device_target",
    type=str,
    default="Ascend",
    choices=["Ascend", "GPU", "CPU"],
    help="device where the code will be implemented (default: Ascend)",
)

# Fix the global random seed for reproducible training.
set_seed(1)
-
-
def get_device_id():
    """Return the device id from the DEVICE_ID env var (defaults to 0)."""
    return int(os.getenv("DEVICE_ID", "0"))
-
-
def get_device_num():
    """Return the total device count from the RANK_SIZE env var (defaults to 1)."""
    return int(os.getenv("RANK_SIZE", "1"))
-
-
def get_rank_id():
    """Return the global rank id from the RANK_ID env var (defaults to 0)."""
    return int(os.getenv("RANK_ID", "0"))
-
-
def get_job_id():
    """Return a constant job identifier used for local (non-cloud) runs."""
    job_id = "Local Job"
    return job_id
-
-
def modelarts_pre_process():
    """Prepare the dataset on ModelArts: extract the dataset archive once per
    server, and block all other devices until extraction has finished.

    Driven by ``config.need_modelarts_dataset_unzip``,
    ``config.modelarts_dataset_unzip_name`` and ``config.data_path``.
    """

    def unzip(zip_file, save_dir):
        # Extract zip_file into save_dir, printing coarse progress.
        import zipfile

        s_time = time.time()
        # Skip extraction when the unzipped directory already exists.
        if not os.path.exists(
                os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, "r")
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                # Report progress roughly once per 1% of archive entries.
                data_print = int(data_num / 100) if data_num > 100 else 1
                i = 0
                for file in fz.namelist():
                    if i % data_print == 0:
                        print(
                            "unzip percent: {}%".format(int(i * 100 /
                                                            data_num)),
                            flush=True,
                        )
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(
                    int((time.time() - s_time) / 60),
                    int(int(time.time() - s_time) % 60),
                ))
                print("Extract Done")
            else:
                print("This is not zip.")
        else:
            print("Zip has been extracted.")

    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path,
                                  config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(config.data_path)

        # File-system sentinel shared by all devices on one server: the first
        # device extracts the archive, the others wait for this file.
        sync_lock = "/tmp/unzip_sync.lock"

        # Each server contains 8 devices as most
        if get_device_id() % min(get_device_num(),
                                 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                # Another process created the lock concurrently; ignore.
                pass

        # Busy-wait until the extracting device publishes the lock file.
        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(
            get_device_id(), zip_file_1, save_dir_1))
        print("#" * 200, os.listdir(save_dir_1))
        print(
            "#" * 200,
            os.listdir(
                os.path.join(config.data_path,
                             config.modelarts_dataset_unzip_name)),
        )

    # config.coco_root = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
    # config.pre_trained = os.path.join(config.coco_root, config.pre_trained)
    # config.save_checkpoint_path = config.output_path
-
-
def create_mindrecord_dir(prefix, mindrecord_dir):
    """Create ``mindrecord_dir`` and convert the raw dataset to MindRecord.

    The conversion source is selected by ``config.dataset``: "coco",
    "icdar", or any other value (generic IMAGE_DIR/ANNO_PATH layout).

    Args:
        prefix (str): file-name prefix for the generated MindRecord files.
        mindrecord_dir (str): directory in which the files are created.

    Raises:
        Exception: for the generic dataset type, when IMAGE_DIR or
            ANNO_PATH does not exist.
    """
    if not os.path.isdir(mindrecord_dir):
        os.makedirs(mindrecord_dir)
        print("YOUR mindrecord_dir: ", mindrecord_dir)
    if config.dataset == "coco":
        if os.path.isdir(config.coco_root):
            print("Create Mindrecord.")
            data_to_mindrecord_byte_image("coco", True, prefix)
            print("Create Mindrecord Done, at {}".format(mindrecord_dir))
    elif config.dataset == "icdar":
        if os.path.isdir(config.icdar_root):
            print("Create ICDAR Mindrecord.")
            data_to_mindrecord_byte_image("icdar", True, prefix)
            print("Create Mindrecord Done, at {}".format(mindrecord_dir))
    else:
        if os.path.isdir(config.IMAGE_DIR) and os.path.exists(
                config.ANNO_PATH):
            print("Create Mindrecord.")
            data_to_mindrecord_byte_image("other", True, prefix)
            print("Create Mindrecord Done, at {}".format(mindrecord_dir))
        else:
            # Typo fix: the message previously read "not exits".
            raise Exception("IMAGE_DIR or ANNO_PATH does not exist.")
-
-
def load_pretrained_ckpt(net, load_path, device_target):
    """Load checkpoint parameters from ``load_path`` into ``net``.

    ``device_target`` is currently unused; it is kept so the signature
    stays compatible with existing callers.
    """
    checkpoint_params = load_checkpoint(load_path)
    load_param_into_net(net, checkpoint_params)
    return net
-
def train_maskrcnn():
    """Train MaskTextSpotter (ResNet-50 MaskRcnn) end to end on ModelArts.

    Steps: parse CLI args and merge cloud paths from ``config``; set up
    optional data-parallel training; download the dataset from OBS via
    moxing; build (rank 0 only) and load the MindRecord dataset; build the
    network, loss, LR schedule, optimizer and callbacks; train; and upload
    checkpoints back to OBS.

    Fixes over the previous revision:
      * both ``except IOError:`` blocks referenced an undefined name ``e``
        (NameError on any moxing failure); they now bind ``as e``;
      * when the MindRecord file already existed, rank 0 never wrote the
        ``tmp.lock`` sentinel and every rank spun forever in the wait
        loop; rank 0 now always writes the lock file;
      * the callback list ``cb`` was only defined when
        ``config.save_checkpoint`` was true, crashing ``model.train``
        otherwise;
      * ``obs_data_url`` / ``args.dataset_dir`` could be unbound.
    """
    args = parser.parse_args()
    args.device_target = config.device_target
    args.multi_data_url = config.multi_data_url
    config.run_distribute = args.run_distribute
    dataset_sink_mode_flag = True
    if config.run_distribute:
        init()
        rank = get_rank()
        context.reset_auto_parallel_context()
        device_num = get_group_size()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num,
        )
        context.set_auto_parallel_context(parameter_broadcast=True)
    else:
        rank = 0
        device_num = 1

    # Fixed ModelArts container paths; the OBS URLs are remembered so the
    # trained model can be uploaded back at the end.
    args.data_url = "/home/work/user-job-dir/inputs/data/"
    obs_train_url = args.train_url
    args.train_url = "/home/work/user-job-dir/outputs/model/"
    obs_multi_data_url = args.multi_data_url
    # multi_data_url is a JSON-like string; strip the punctuation and scan
    # the fragments for the "dataset_url" entry.
    str_new = (obs_multi_data_url.replace("[", "").replace("]", "")
               .replace("{", "").replace("}", "").replace('"', "").split(", "))
    obs_data_url = None
    for item in str_new:
        if "dataset_url" in item:
            obs_data_url = item.replace("dataset_url:", "")
    if obs_data_url is not None:
        try:
            mox.file.copy_parallel(obs_data_url, args.data_url)
            print("Successfully Download {} to {}".format(
                obs_data_url, args.data_url))
        except IOError as e:
            print("moxing download {} to {} failed: ".format(
                obs_data_url, args.data_url) + str(e))

    args.dataset_path = args.data_url
    args.icdar_root = args.data_url
    args.save_checkpoint_path = args.train_url
    args.outputs_dir = os.path.join(args.train_url, "./")
    # Default so dataset_dir is always defined, even off ModelArts.
    args.dataset_dir = args.data_url
    print("\ntrain.py config:\n", config)
    print("Start train for maskrcnn!")
    if config.enable_modelarts:
        local_data_url = args.dataset_path
        print("copy from {} to {}".format(config.data_url, local_data_url))
        mox.file.copy_parallel(config.data_url, local_data_url)
        print("config.data_url:", config.data_url)
        args.dataset_dir = local_data_url
        args.icdar_root = local_data_url

    prefix = "MaskRcnn.mindrecord"
    mindrecord_dir = os.path.join(args.dataset_dir, config.mindrecord_train_dir)
    mindrecord_file = os.path.join(mindrecord_dir, prefix)
    print(mindrecord_file)
    print("Start create dataset!")
    lock_path = os.path.join(mindrecord_dir, "tmp.lock")
    if rank == 0:
        if not os.path.exists(mindrecord_file):
            create_mindrecord_dir(prefix, mindrecord_dir)
        # Always publish the lock file, even when the mindrecord already
        # exists; otherwise every rank would wait below forever.
        with open(lock_path, "w") as lock_file:
            lock_file.write("tmp_lock")
    # Non-zero ranks wait here until rank 0 signals the dataset is ready.
    while not os.path.exists(lock_path):
        time.sleep(5)

    dataset = create_maskrcnn_dataset(mindrecord_file,
                                      batch_size=config.batch_size,
                                      device_num=device_num, rank_id=rank)
    dataset_size = dataset.get_dataset_size()
    print("total images num: ", dataset_size)
    print("Create dataset done!")

    net = MaskTextSpotter_Resnet50(config=config)
    net = net.set_train()
    loss = LossNet()
    lr = Tensor(
        dynamic_lr(config, rank_size=device_num, dsize=dataset_size,
                   start_steps=config.pretrain_epoch_size * dataset_size),
        mstype.float32,
    )
    optim_sgd = nn.SGD(params=net.trainable_params(), learning_rate=lr,
                       momentum=0.9)
    net_with_loss = WithLossCell(net, loss)
    if config.run_distribute:
        net = TrainOneStepCell(net_with_loss, optim_sgd,
                               sens=config.loss_scale, reduce_flag=True,
                               mean=True, degree=device_num)
    else:
        net = TrainOneStepCell(net_with_loss, optim_sgd,
                               sens=config.loss_scale)

    save_checkpoint_path = os.path.join(args.outputs_dir,
                                        "ckpt_" + str(rank) + "/")
    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossCallBack(rank_id=rank, save_path=save_checkpoint_path)
    cb = [time_cb, loss_cb]
    if config.save_checkpoint:
        ckptconfig = CheckpointConfig(
            save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
            keep_checkpoint_max=config.keep_checkpoint_max,
        )
        print("saving model at:", save_checkpoint_path)
        cb.append(ModelCheckpoint(prefix="mask_text_spotter_",
                                  directory=save_checkpoint_path,
                                  config=ckptconfig))

    model = Model(net)
    model.train(config.epoch_size, dataset, callbacks=cb,
                dataset_sink_mode=dataset_sink_mode_flag)

    try:
        mox.file.copy_parallel(args.train_url, obs_train_url)
        print("Successfully Upload {} to {}".format(
            args.train_url, obs_train_url))
    except IOError as e:
        print("moxing upload {} to {} failed: ".format(
            args.train_url, obs_train_url) + str(e))
-
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    train_maskrcnn()
|