Browse Source

Add 'train_ln_adam.py'

master
unicorn 1 month ago
parent
commit
52bfe12526
1 changed files with 245 additions and 0 deletions
  1. +245
    -0
      train_ln_adam.py

+ 245
- 0
train_ln_adam.py View File

@@ -0,0 +1,245 @@
import sys

sys.path.append('../')
import os
import json
import librosa
import moxing as mox
import argparse
from data_test1 import DatasetGenerator
import mindspore.dataset as ds
from mindspore import Model, load_checkpoint, load_param_into_net
from mindspore import nn, context, set_seed
from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from network_define import WithLossCell
from Loss_final1 import loss
from model_asteroid import Dual_RNN_model
# from mindspore.nn.dynamic_lr import piecewise_constant_lr
# from lr_sch import dynamic_lr
import time

# from mindspore.profiler import Profiler

def _str2bool(value):
    """Parse a CLI boolean flag value ('true'/'1'/'yes' → True, else False).

    argparse's ``type=bool`` is a well-known pitfall: ``bool('False')`` is
    True because any non-empty string is truthy, so '--run_distribute False'
    would silently enable distribution. This converter fixes that.
    """
    return str(value).strip().lower() in ('true', '1', 'yes', 'y')


parser = argparse.ArgumentParser(
    description='Parameters for training Dual-Path-RNN')

# Data locations
parser.add_argument('--in_dir', type=str, default=r"/home/work/user-job-dir/inputs/data/",
                    help='Directory path of wsj0 including tr, cv and tt')
parser.add_argument('--out_dir', type=str, default=r"/home/work/user-job-dir/inputs/data_json",
                    help='Directory path to put output files')
parser.add_argument('--data_url',
                    help='path to training/inference dataset folder',
                    default='/home/work/user-job-dir/inputs/data/')
parser.add_argument('--train_url',
                    help='model folder to save/load',
                    default='/home/work/user-job-dir/model/')
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'GPU', 'CPU'],
    help='device where the code will be implemented (default: Ascend)')

parser.add_argument('--train_dir', type=str, default=r"/home/work/user-job-dir/inputs/data_json/tr",
                    help='directory including mix.json, s1.json and s2.json')
parser.add_argument('--sample_rate', default=8000, type=int,
                    help='Sample rate')
# Audio clip length in seconds; all utterances in the dataset must share it.
parser.add_argument('--segment', default=4, type=float,
                    help='Segment length (seconds)')
parser.add_argument('--batch_size', default=3, type=int,
                    help='Batch size')
parser.add_argument('--threads', type=int, default=8,
                    help='number of threads for data loader to use')

# Network architecture
parser.add_argument('--in_channels', default=64, type=int,
                    help='The number of expected features in the input')
parser.add_argument('--out_channels', default=64, type=int,
                    help='The number of features in the hidden state')
parser.add_argument('--hidden_channels', default=128, type=int,
                    help='The hidden size of RNN')
parser.add_argument('--bn_channels', default=128, type=int,
                    help='Number of channels after the conv1d')
parser.add_argument('--kernel_size', default=2, type=int,
                    help='Encoder and Decoder Kernel size')
parser.add_argument('--rnn_type', default='LSTM', type=str,
                    help='RNN, LSTM, GRU')
parser.add_argument('--norm', default='gln', type=str,
                    help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"')
parser.add_argument('--dropout', default=0.0, type=float,
                    help='dropout')
parser.add_argument('--num_layers', default=6, type=int,
                    help='Number of Dual-Path-Block')
parser.add_argument('--K', default=250, type=int,
                    help='The length of chunk')
parser.add_argument('--num_spks', default=2, type=int,
                    help='The number of speakers')

# Optimizer
parser.add_argument('--optimizer', default='adam', type=str,
                    choices=['sgd', 'adam'],
                    help='Optimizer (support sgd and adam now)')
parser.add_argument('--lr', default=5e-4, type=float,
                    help='Init learning rate')
parser.add_argument('--momentum', default=0.0, type=float,
                    help='Momentum for optimizer')
parser.add_argument('--l2', default=1e-5, type=float,
                    help='weight decay (L2 penalty)')

# Save and load model
parser.add_argument('--save_folder', default=r"/home/work/user-job-dir/model/",
                    help='Location to save epoch models')
# NOTE: the three help strings below previously read 'Sample rate of audio
# file' — a copy-paste error; corrected to describe the actual options.
parser.add_argument('--device_num', type=int, default=2,
                    help='Number of devices used for distributed training')
parser.add_argument('--device_id', type=int, default=0,
                    help='Device ID to use when running on a single device')
# BUG FIX: was type=bool, which parses any non-empty string as True.
parser.add_argument('--run_distribute', type=_str2bool, default=True,
                    help='Whether to run distributed training (true/false)')

def preprocess_one_dir(in_dir, out_dir, out_filename, sample_rate=8000):
    """Write a JSON manifest of the .wav files found in *in_dir*.

    Each manifest entry is a ``(absolute_path, num_samples)`` pair; the
    manifest is saved to ``<out_dir>/<out_filename>.json``.

    Args:
        in_dir: directory to scan for .wav files (non-recursive).
        out_dir: directory the manifest is written to (created if missing).
        out_filename: manifest name without the .json extension.
        sample_rate: rate the audio is resampled to while loading.
    """
    in_dir = os.path.abspath(in_dir)
    infos = []
    for name in os.listdir(in_dir):
        if not name.endswith('.wav'):
            continue
        path = os.path.join(in_dir, name)
        # librosa resamples to sample_rate while loading, so len(audio)
        # is the sample count at the target rate.
        audio, _ = librosa.load(path, sr=sample_rate)
        infos.append((path, len(audio)))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    with open(os.path.join(out_dir, out_filename + '.json'), 'w') as f:
        json.dump(infos, f, indent=4)

def preprocess(args):
    """Build JSON manifests for the training split of the dataset.

    For each source folder ('mix', 's1', 's2') under ``args.in_dir/tr`` a
    manifest is written into ``args.out_dir/tr``.
    """
    data_types = ('tr',)
    speakers = ('mix', 's1', 's2')
    for data_type in data_types:
        out_sub_dir = os.path.join(args.out_dir, data_type)
        for speaker in speakers:
            in_sub_dir = os.path.join(args.in_dir, data_type, speaker)
            preprocess_one_dir(in_sub_dir, out_sub_dir, speaker,
                               sample_rate=args.sample_rate)
    print("preprocess done")

def main(args):
    """Train the Dual-Path-RNN speech-separation model.

    Downloads the dataset from OBS into the training image, builds JSON
    manifests, constructs the data pipeline / network / loss / optimizer,
    trains, and uploads the resulting checkpoints back to OBS.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True)

    if args.run_distribute:
        print("distribute")
        # DEVICE_ID is injected by the launch framework; default to "0" so a
        # missing variable fails gracefully instead of int(None) raising.
        device_id = int(os.getenv("DEVICE_ID", "0"))
        device_num = args.device_num
        context.set_context(device_id=device_id)
        init()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)
        rank_id = get_rank()          # this device's rank within the cluster
        rank_size = get_group_size()  # total number of devices in the cluster
    else:
        device_id = args.device_id
        context.set_context(device_id=device_id)
        # BUG FIX: rank_id/rank_size were only assigned on the distributed
        # path, so single-device runs crashed with NameError below. A single
        # device is rank 0 in a group of size 1.
        rank_id = 0
        rank_size = 1

    home = os.path.dirname(os.path.realpath(__file__))
    obs_data_url = args.data_url
    args.data_url = '/home/work/user-job-dir/inputs/data/'
    obs_train_url = args.train_url

    # Checkpoint directory, one folder per rank / device.
    save_folder = os.path.join(home, 'checkpoints') + str(rank_id)
    if not os.path.exists(save_folder):
        os.mkdir(save_folder)
    # BUG FIX: str(device_id) instead of os.getenv('DEVICE_ID'), which is
    # None (TypeError on concatenation) when the env var is not exported.
    save_checkpoint_path = save_folder + '/device_' + str(device_id) + '/'
    if not os.path.exists(save_checkpoint_path):
        os.makedirs(save_checkpoint_path)
    save_ckpt = os.path.join(save_checkpoint_path, 'dprnnCkpt')

    # Copy the dataset from OBS into the training image (ModelArts idiom);
    # best-effort: failure is logged, and preprocessing proceeds on whatever
    # data is already present locally.
    try:
        mox.file.copy_parallel(obs_data_url, args.data_url)
        print("Successfully Download {} to {}".format(obs_data_url, args.data_url))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(
            obs_data_url, args.data_url) + str(e))

    print("start preprocess ....")
    preprocess(args)

    set_seed(42)
    print("Preparing Data")
    start_time = time.perf_counter()
    # Build the training data pipeline, sharded across ranks.
    tr_dataset = DatasetGenerator(args.train_dir, args.batch_size,
                                  sample_rate=args.sample_rate, segment=args.segment)
    tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"],
                                    num_parallel_workers=args.threads,
                                    shuffle=False, num_shards=rank_size, shard_id=rank_id)
    # NOTE(review): the pipeline batch size is hard-coded to 2 while
    # --batch_size is only forwarded to DatasetGenerator — confirm intended.
    tr_loader = tr_loader.batch(2)
    num_steps = tr_loader.get_dataset_size()
    end_time = time.perf_counter()
    print("preparing data use: {}min".format((end_time - start_time) / 60))

    # Build the separation network.
    net = Dual_RNN_model(args.in_channels, args.out_channels, args.hidden_channels, args.bn_channels,
                         bidirectional=True, norm=args.norm, num_layers=args.num_layers,
                         dropout=args.dropout, K=args.K)
    print(net)
    net.set_train()

    # Optimizer, loss and the training wrapper.
    optimizier = nn.Adam(net.trainable_params(), learning_rate=args.lr, weight_decay=args.l2)
    my_loss = loss()
    net_with_loss = WithLossCell(net, my_loss)
    model = Model(net_with_loss, optimizer=optimizier)

    # Callbacks: step timing, per-step loss, checkpoint once per epoch
    # (save_checkpoint_steps == steps per epoch), keep the 5 most recent.
    time_cb = TimeMonitor()
    loss_cb = LossMonitor(1)
    cb = [time_cb, loss_cb]
    config_ck = CheckpointConfig(save_checkpoint_steps=num_steps, keep_checkpoint_max=5)
    ckpt_cb = ModelCheckpoint(prefix='DPRNN',
                              directory=save_ckpt,
                              config=config_ck)
    cb += [ckpt_cb]

    print("============== Starting Training ==============")
    model.train(epoch=17, train_dataset=tr_loader, callbacks=cb, dataset_sink_mode=False)

    # Upload the trained checkpoints back to OBS (best-effort, logged).
    try:
        mox.file.copy_parallel(save_folder, obs_train_url)
        print("Successfully Upload {} to {}".format(save_folder, obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(save_folder, obs_train_url) + str(e))

if __name__ == '__main__':
    # Parse command-line arguments, echo them for the job log, and train.
    cli_args = parser.parse_args()
    print(cli_args)
    main(cli_args)

Loading…
Cancel
Save