liupengfei · 1 month ago · commit a1cd02c8d6
2 changed files with 26 additions and 18 deletions
  1. default_config.yaml (+1 -1)
  2. train.py (+25 -17)

default_config.yaml (+1 -1)

@@ -27,7 +27,7 @@ train_data: "FlyingChairs" # Train Dataset name
 train_data_path: "/home/work/user-job-dir/data/FlyingChairs_part" # Train Dataset path

 # Train Setup
-run_distribute: 1 # Distributed training or not
+run_distribute: 0 # Distributed training or not
 is_save_on_master: 1 # Only save ckpt on master device
 save_checkpoint: 1 # Is save ckpt while training
 save_ckpt_interval: 2 # Saving ckpt interval
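
With this change, run_distribute defaults to single-device training and has to be switched back to 1 for multi-card runs. As a rough illustration only (the repo's own config loader is not part of this commit, so the loading code and the relative file path below are assumptions), reading the flag with PyYAML would look roughly like this:

# Illustrative sketch: read the flag changed above and report the training mode.
import yaml

with open("default_config.yaml", "r") as f:   # file name taken from this commit
    cfg = yaml.safe_load(f)

mode = "distributed" if cfg.get("run_distribute", 0) == 1 else "standalone"
print("run_distribute =", cfg.get("run_distribute"), "->", mode, "training")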


train.py (+25 -17)

@@ -47,11 +47,8 @@ import argparse
 def set_save_ckpt_dir():
     """set save ckpt dir"""
     ckpt_save_dir = config.save_checkpoint_path
-    if config.enable_modelarts and config.run_distribute:
-        ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str(get_rank_id()) + "/"
-    else:
-        if config.run_distribute:
-            ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str(get_rank()) + "/"
+    if config.run_distribute:
+        ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str(get_rank()) + "/"
     return ckpt_save_dir
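
As a usage note, the simplified helper now appends one per-rank subdirectory whenever run_distribute is set; the ModelArts-specific branch using get_rank_id() is gone. A tiny standalone sketch with hypothetical paths and ranks (not repo code) shows the resulting layout:

# Sketch of the directory naming produced by the simplified logic above.
def demo_ckpt_dir(save_checkpoint_path, run_distribute, rank):
    ckpt_save_dir = save_checkpoint_path
    if run_distribute:
        ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str(rank) + "/"
    return ckpt_save_dir

print(demo_ckpt_dir("./ckpt", run_distribute=0, rank=0))  # ./ckpt
print(demo_ckpt_dir("./ckpt", run_distribute=1, rank=3))  # ./ckpt/ckpt_3/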


@@ -151,21 +148,32 @@ def run_train():
     # profiles = Profiler()
     ds.config.set_enable_shared_mem(False)
     device_num = get_device_num()
-    if config.device_target == "Ascend":
-        config.device_id = get_device_id()
-        # TODO lpf: is there an issue with how rank and group_size are set here?
-        rank = get_rank()
-        group_size = get_group_size()
-        # ms.set_context(device_id=config.device_id)
-    if device_num > 1:
-        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True)
+    if config.device_target == "Ascend":
+        config.device_id = get_device_id()
+    if config.run_distribute == 1:
+        init()
+
+        device_num = get_group_size()
+        print('device_num ccc =', device_num)
+        parallel_mode = ParallelMode.DATA_PARALLEL
+        rank = get_rank()
+        group_size = get_group_size()
+        print('rank ccc =', rank)
+        print('group_size ccc =', group_size)
+
+        # context.reset_auto_parallel_context()
+        # context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+        #                                   gradients_mean=True)
+    else:
+        print('log this ')
+        parallel_mode = ParallelMode.STAND_ALONE
+        rank = 0
+        group_size = 1
+        print('rank ddd =', rank)
+        print('group_size ddd =', group_size)
+
+    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=group_size)
     # load dataset by config param
     config.training_dataset_class = tools.module_to_dict(datasets)[config.train_data]
+    print('config.train_data_path =', config.train_data_path)
     flownet_train_gen = config.training_dataset_class(config.crop_type, config.crop_size, config.eval_size,
                                                       config.train_data_path)
     sampler = datasets.DistributedSampler(flownet_train_gen, rank=rank, group_size=group_size, shuffle=True)
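
The run_train() rewrite above follows the usual MindSpore recipe: initialize collectives and take the real rank and group_size under DATA_PARALLEL when run_distribute is on, otherwise fall back to STAND_ALONE with rank 0 and group size 1, then feed both values into the dataset sampler. A condensed, self-contained sketch of that pattern, assuming a working MindSpore install and, for the distributed path, a properly launched HCCL/NCCL environment (setup_parallel is a hypothetical helper name, not from this repo):

# Condensed sketch of the pattern introduced above, not a drop-in replacement.
from mindspore import context
from mindspore.communication import init, get_rank, get_group_size
from mindspore.context import ParallelMode

def setup_parallel(run_distribute):
    if run_distribute == 1:
        init()                                   # start HCCL/NCCL collectives
        rank, group_size = get_rank(), get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        rank, group_size = 0, 1                  # single-device fallback
        parallel_mode = ParallelMode.STAND_ALONE
    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      gradients_mean=True,
                                      device_num=group_size)
    return rank, group_size

# rank and group_size then drive dataset sharding, e.g. the repo's
# datasets.DistributedSampler(dataset, rank=rank, group_size=group_size, shuffle=True)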

