|
- import os
- import json
- import librosa
- import moxing as mox
- import argparse
- from dataset import create_train_dataset
- from data_test1 import DatasetGenerator
- import mindspore.dataset as ds
- from mindspore import Model, load_checkpoint, load_param_into_net
- from mindspore import nn, context, set_seed
- from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
- from mindspore.communication.management import init, get_rank, get_group_size
- from mindspore.context import ParallelMode
- from network_define import WithLossCell
- from Loss_final1 import loss
- from model_rnn import Dual_RNN_model
- from lr_sch import dynamic_lr
- from train_wrapper import TrainingWrapper
-
def _str2bool(value):
    """Parse a command-line boolean string.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is True because any
    non-empty string is truthy, so ``--run_distribute False`` used to be
    silently ignored.  This helper accepts the common spellings and raises a
    clear parse error for anything else.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ('true', 't', 'yes', 'y', '1'):
        return True
    if value.lower() in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % value)


parser = argparse.ArgumentParser(
    description='Parameters for training Dual-Path-RNN')

# Data locations (defaults match the ModelArts/OpenI job directory layout).
parser.add_argument('--in_dir', type=str, default=r"/home/work/user-job-dir/inputs/data/",
                    help='Directory path of wsj0 including tr, cv and tt')
parser.add_argument('--out_dir', type=str, default=r"/home/work/user-job-dir/inputs/data_json",
                    help='Directory path to put output files')
parser.add_argument('--data_url',
                    help='path to training/inference dataset folder',
                    default='/home/work/user-job-dir/inputs/data/')

parser.add_argument('--train_url',
                    help='model folder to save/load',
                    default='/home/work/user-job-dir/model/')
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'GPU', 'CPU'],
    help='device where the code will be implemented (default: Ascend)')

parser.add_argument('--train_dir', type=str, default=r"/home/work/user-job-dir/inputs/data_json/tr",
                    help='directory including mix.json, s1.json and s2.json')
parser.add_argument('--sample_rate', default=8000, type=int,
                    help='Sample rate')
# Segment length in seconds; all training utterances must share this length.
parser.add_argument('--segment', default=4, type=float,
                    help='Segment length (seconds)')
parser.add_argument('--data_batch_size', default=2, type=int,
                    help='Batch size')
# Help text fixed: the original said 'Sample rate of audio file' (copy-paste).
parser.add_argument('--batch_size', type=int, default=2,
                    help='Training batch size')

# Network architecture
parser.add_argument('--in_channels', default=256, type=int,
                    help='The number of expected features in the input')
parser.add_argument('--out_channels', default=64, type=int,
                    help='The number of features in the hidden state')
parser.add_argument('--hidden_channels', default=128, type=int,
                    help='The hidden size of RNN')
parser.add_argument('--kernel_size', default=2, type=int,
                    help='Encoder and Decoder Kernel size')
parser.add_argument('--rnn_type', default='LSTM', type=str,
                    help='RNN, LSTM, GRU')
parser.add_argument('--norm', default='gln', type=str,
                    help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"')
parser.add_argument('--dropout', default=0.0, type=float,
                    help='dropout')
parser.add_argument('--num_layers', default=6, type=int,
                    help='Number of Dual-Path-Block')
parser.add_argument('--K', default=250, type=int,
                    help='The length of chunk')
parser.add_argument('--num_spks', default=2, type=int,
                    help='The number of speakers')

# optimizer
parser.add_argument('--lr', default=0.001, type=float,
                    help='Init learning rate')
parser.add_argument('--l2', default=1e-5, type=float,
                    help='weight decay (L2 penalty)')

# save and load model (help texts fixed; originals were copy-paste errors)
parser.add_argument('--save_folder', default=r"/home/work/user-job-dir/model/",
                    help='Location to save epoch models')
parser.add_argument('--device_num', type=int, default=2,
                    help='Number of devices for distributed training')
parser.add_argument('--device_id', type=int, default=0,
                    help='Device id to run on (single-device mode)')
# type=_str2bool instead of type=bool: bool('False') is True, which made it
# impossible to disable distribution from the command line.
parser.add_argument('--run_distribute', type=_str2bool, default=True,
                    help='Whether to run distributed (data-parallel) training')
-
def preprocess_one_dir(in_dir, out_dir, out_filename, sample_rate=8000):
    """Index every ``.wav`` file in *in_dir* and dump the index as JSON.

    Writes ``<out_dir>/<out_filename>.json`` containing a list of
    ``[absolute_wav_path, num_samples]`` pairs, where ``num_samples`` is the
    waveform length after librosa resamples the file to *sample_rate*.

    Args:
        in_dir (str): directory holding the .wav files (non-recursive).
        out_dir (str): directory the JSON index is written to (created if missing).
        out_filename (str): basename (without extension) of the JSON file.
        sample_rate (int): target sample rate passed to ``librosa.load``.
    """
    file_infos = []
    in_dir = os.path.abspath(in_dir)
    # sorted() makes the index deterministic; os.listdir order is
    # filesystem-dependent and would otherwise vary between runs/hosts.
    for wav_file in sorted(os.listdir(in_dir)):
        if not wav_file.endswith('.wav'):
            continue
        wav_path = os.path.join(in_dir, wav_file)
        samples, _ = librosa.load(wav_path, sr=sample_rate)
        file_infos.append((wav_path, len(samples)))
    # exist_ok avoids the check-then-create race of the original
    # exists()/makedirs() pair.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, out_filename + '.json'), 'w') as f:
        json.dump(file_infos, f, indent=4)
-
def preprocess(args):
    """Build the JSON file indexes for the training split.

    For the ``tr`` split, indexes the ``mix``, ``s1`` and ``s2``
    sub-directories of ``args.in_dir`` into ``args.out_dir`` via
    :func:`preprocess_one_dir`.
    """
    splits = ('tr',)
    speakers = ('mix', 's1', 's2')
    for split in splits:
        dst_dir = os.path.join(args.out_dir, split)
        for spk in speakers:
            src_dir = os.path.join(args.in_dir, split, spk)
            preprocess_one_dir(src_dir, dst_dir, spk,
                               sample_rate=args.sample_rate)
    print("preprocess done")
-
def main(args):
    """Train Dual-Path-RNN on Ascend/GPU/CPU devices.

    Workflow: configure the MindSpore context (optionally data-parallel),
    download the dataset from OBS via moxing, build the JSON indexes,
    construct model / loss / optimizer, train for one epoch with
    checkpointing, then upload the checkpoints back to OBS.

    Args:
        args (argparse.Namespace): parsed command-line options (see parser).
    """
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    if args.run_distribute:
        print("distribute")
        # DEVICE_ID is injected by the Ascend launcher; default to "0" so a
        # missing variable doesn't raise TypeError inside int().
        device_id = int(os.getenv("DEVICE_ID", "0"))
        device_num = args.device_num
        context.set_context(device_id=device_id)
        init()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)
        rank_id = get_rank()          # this device's id within the cluster
        rank_size = get_group_size()  # number of devices in the cluster
    else:
        device_id = args.device_id
        context.set_context(device_id=device_id)
        # BUG FIX: rank_id was only bound in the distributed branch but is
        # used unconditionally below to build save_folder, so single-device
        # runs crashed with NameError. A single device is always rank 0.
        rank_id = 0

    home = os.path.dirname(os.path.realpath(__file__))
    obs_data_url = args.data_url
    args.data_url = '/home/work/user-job-dir/inputs/data/'
    obs_train_url = args.train_url

    # Checkpoint root, suffixed with the rank so parallel workers don't
    # write into the same directory.
    save_folder = os.path.join(home, 'checkpoints') + str(rank_id)
    os.makedirs(save_folder, exist_ok=True)

    save_checkpoint_path = save_folder + '/device_' + os.getenv('DEVICE_ID', '0') + '/'
    os.makedirs(save_checkpoint_path, exist_ok=True)
    save_ckpt = os.path.join(save_checkpoint_path, 'dualCkpt')

    # Copy the dataset from OBS into the training container (boilerplate for
    # the ModelArts/OpenI platform). Best-effort: a failed copy is logged,
    # and training then fails later with a clearer dataset error.
    try:
        mox.file.copy_parallel(obs_data_url, args.data_url)
        print("Successfully Download {} to {}".format(obs_data_url, args.data_url))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(obs_data_url, args.data_url) + str(e))

    print("start preprocess ....")
    preprocess(args)

    set_seed(42)
    print("Preparing Data")
    # Build the training dataloader from the JSON indexes produced above.
    tr_dataset = DatasetGenerator(args.train_dir, args.data_batch_size,
                                  sample_rate=args.sample_rate, segment=args.segment)
    tr_loader = create_train_dataset(tr_dataset, args)
    num_steps = tr_loader.get_dataset_size()

    # Build the model.
    net = Dual_RNN_model(args.in_channels, args.out_channels, args.hidden_channels,
                         bidirectional=True, norm=args.norm, num_layers=args.num_layers,
                         dropout=args.dropout, K=args.K)
    print(net)

    # Optimizer + loss; TrainingWrapper combines forward, loss and backward.
    optimizer = nn.Adam(net.trainable_params(), learning_rate=args.lr, weight_decay=args.l2)
    my_loss = loss()
    net_with_loss = WithLossCell(net, my_loss)
    net_with_loss_ = TrainingWrapper(net_with_loss, optimizer)
    net_with_loss_.set_train()

    model = Model(net_with_loss_)

    time_cb = TimeMonitor()
    loss_cb = LossMonitor(1)
    cb = [time_cb, loss_cb]

    config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5)
    ckpt_cb = ModelCheckpoint(prefix='dual',
                              directory=save_ckpt,
                              config=config_ck)
    cb += [ckpt_cb]

    print("============== Starting Training ==============")
    model.train(epoch=1, train_dataset=tr_loader, callbacks=cb, dataset_sink_mode=False)

    # Copy the trained checkpoints back to OBS so the platform exposes them
    # for download. Best-effort, mirroring the download above.
    try:
        mox.file.copy_parallel(save_folder, obs_train_url)
        print("Successfully Upload {} to {}".format(save_folder, obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(save_folder, obs_train_url) + str(e))
-
if __name__ == '__main__':
    # Script entry point: parse CLI options, echo them for the job log,
    # then hand off to main().
    cli_args = parser.parse_args()
    print(cli_args)
    main(cli_args)
|