|
- # import hydra
- import os
- import json
- import librosa
- import moxing as mox
- import argparse
-
- from data import DatasetGenerator, DistributedSampler
- # os.environ["LD_PRELOAD"] = "_check_build.cpython-37m-aarch64-linux-gnu.so__init__.py"
- from svoice.models.swave import SWave
- from mindspore import Model
- # from svoice.data.data_test2_5_5 import DatasetGenerator
- from mindspore import save_checkpoint, set_seed, load_checkpoint, load_param_into_net
- from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext
- import mindspore.dataset as ds
- from mindspore import nn
- from mindspore import log as logger
- # from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
- # from svoice.network_define import WithLossCell
- # from svoice.models.Loss_final1 import myloss
- from generatorloss import Generatorloss
- from trainonestep import TrainOneStep
- from svoice.network_define import WithLossCell
- from svoice.models.Loss_final1 import myloss
- import time
- import zipfile
- from mindspore import context
- from mindspore.context import ParallelMode
- from mindspore.communication.management import init, get_rank, get_group_size
-
# Command-line interface.  NOTE: despite the parser description, this script
# both preprocesses the WSJ0 data and runs training.
parser = argparse.ArgumentParser("WSJ0 data preprocessing")
parser.add_argument('--in-dir', type=str, default=r"/home/work/user-job-dir/inputs/data/",
                    help='Directory path of wsj0 including tr, cv and tt')
parser.add_argument('--out-dir', type=str, default=r"/home/work/user-job-dir/inputs/data_json",
                    help='Directory path to put output files')
parser.add_argument('--sample-rate', type=int, default=8000,
                    help='Sample rate of audio file')
parser.add_argument('--data_url',
                    help='path to training/inference dataset folder',
                    default='/home/work/user-job-dir/inputs/data/')
parser.add_argument('--train_url',
                    help='model folder to save/load',
                    default='/home/work/user-job-dir/model/')
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'GPU', 'CPU'],
    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--segment', type=int, default=4,
                    help='Length (in seconds) of each training segment')
parser.add_argument('--batch_size', type=int, default=6,
                    help='Number of samples per training batch')
parser.add_argument('--epochs', type=int, default=100,
                    help='Number of training epochs')
parser.add_argument('--device_num', type=int, default=8,
                    help='Number of devices for distributed training')
parser.add_argument('--device_id', type=int, default=0,
                    help='Device id for single-device training')
# NOTE: argparse's type=bool would treat ANY non-empty string (including
# "False") as True, so the value is parsed explicitly instead.
parser.add_argument('--run_distribute', type=lambda s: s.lower() in ('true', '1', 'yes'),
                    default=True,
                    help='Whether to run distributed (multi-device) training')
parser.add_argument('--data_batch_size', type=int, default=3,
                    help='Batch size used by the dataset generator')
parser.add_argument('--train', type=str, default='/home/work/user-job-dir/inputs/data_json/tr',
                    help='Directory containing the training-set json files')
parser.add_argument('--valid', type=str, default="/home/work/user-job-dir/inputs/data_json/tr",
                    help='Directory containing the validation-set json files')
parser.add_argument('--test', type=str, default="/home/work/user-job-dir/inputs/data_json/tr",
                    help='Directory containing the test-set json files')
parser.add_argument('--lr', type=float, default=5e-4,
                    help='Initial learning rate')
parser.add_argument('--gamma', type=float, default=0.98,
                    help='Learning-rate decay factor')
parser.add_argument('--beta2', type=float, default=0.999,
                    help='beta2 parameter of the Adam optimizer')
parser.add_argument('--snapshots', type=int, default=1, help='Snapshots')
parser.add_argument('--prefix', default='tpami_residual_filter8', help='Location to save checkpoint models')
parser.add_argument('--model_type', type=str, default='swave')
parser.add_argument('--pretrained_sr', default='25_gdprnn.ckpt', help='sr pretrained base model')
-
# Training loop: iterates over epochs, logs loss, saves and uploads checkpoints.
def train(trainoneStep, data, train_dir, obs_train_url, args):
    """Run the training loop and upload checkpoints after every epoch.

    Args:
        trainoneStep: TrainOneStep cell that performs one optimizer step on a
            batch and returns the loss tensor.
        data: dict with a 'tr_loader' entry holding the batched training dataset.
        train_dir: local directory where per-device checkpoints are saved.
        obs_train_url: remote (OBS) directory the checkpoints are copied to.
        args: parsed command-line arguments (uses args.epochs).
    """
    trainoneStep.set_train()
    trainoneStep.set_grad()
    tr_loader = data['tr_loader']
    step = tr_loader.get_dataset_size()

    for epoch in range(args.epochs):
        total_loss = 0
        batch_idx = 0
        # NOTE: renamed the loop variable (was `data`, shadowing the parameter)
        # and the length column (was `len`, shadowing the builtin).
        for batch in tr_loader:
            mixture, lens, source = [x for x in batch]
            t0 = time.time()
            loss = trainoneStep(mixture, lens, source)
            t1 = time.time()
            if batch_idx % 100 == 0:
                print("epoch[{}]({}/{}),loss:{:.4f},stepTime:{}".format(
                    epoch + 1, batch_idx + 1, step, loss.asnumpy(), t1 - t0))
            batch_idx += 1
            total_loss += loss
        train_loss = total_loss / batch_idx
        print("epoch[{}]:trainAvgLoss:{:.4f}".format(epoch + 1, train_loss.asnumpy()))

        print('===> Saving model')
        # One checkpoint sub-directory per device so parallel workers do not
        # clobber each other.  Default to '0' when DEVICE_ID is unset so the
        # path concatenation cannot fail on None.
        save_checkpoint_path = train_dir + '/device_' + os.getenv('DEVICE_ID', '0') + '/'
        if not os.path.exists(save_checkpoint_path):
            os.makedirs(save_checkpoint_path)
        save_ckpt = os.path.join(save_checkpoint_path, '{}_gdprnn.ckpt'.format(epoch + 1))
        save_checkpoint(trainoneStep.network, save_ckpt)

        # Best-effort upload to OBS; training continues even if the copy fails.
        try:
            mox.file.copy_parallel(train_dir, obs_train_url)
            print("Successfully Upload {} to {}".format(train_dir,
                                                        obs_train_url))
        except Exception as e:
            print('moxing upload {} to {} failed: '.format(train_dir,
                                                           obs_train_url) + str(e))
-
-
-
def preprocess_one_dir(in_dir, out_dir, out_filename, sample_rate=8000):
    """Index the .wav files of one directory into a json manifest.

    Loads every ``*.wav`` under *in_dir* at *sample_rate* and writes a list of
    ``(absolute_path, num_samples)`` pairs to ``out_dir/out_filename.json``.
    """
    source_dir = os.path.abspath(in_dir)
    entries = []
    for name in os.listdir(source_dir):
        # Skip anything that is not a wav file.
        if not name.endswith('.wav'):
            continue
        full_path = os.path.join(source_dir, name)
        audio, _ = librosa.load(full_path, sr=sample_rate)
        entries.append((full_path, len(audio)))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    manifest_path = os.path.join(out_dir, out_filename + '.json')
    with open(manifest_path, 'w') as handle:
        json.dump(entries, handle, indent=4)
-
def preprocess(args):
    """Build json manifests for every dataset split and speaker directory."""
    splits = ['tr']
    speakers = ['mix', 's1', 's2']
    for split in splits:
        dest_dir = os.path.join(args.out_dir, split)
        for speaker in speakers:
            src_dir = os.path.join(args.in_dir, split, speaker)
            preprocess_one_dir(src_dir, dest_dir, speaker,
                               sample_rate=args.sample_rate)
-
def poly_lr(base_lr, epoch_steps, total_steps):
    """Yield one learning-rate value at every epoch boundary.

    Walks step indices 0..total_steps-1 and emits a value whenever the index
    is a multiple of *epoch_steps*.  Every emitted value is ``base_lr * 0.98``
    (a single fixed decay factor; the rate does not compound across epochs).
    """
    decayed = base_lr * 0.98  # invariant, so hoisted out of the loop
    for step_idx in range(total_steps):
        if step_idx % epoch_steps:
            continue
        yield decayed
-
-
- #@hydra.main(config_path="conf", config_name='config.yaml')
def main(args):
    """Set up the device context, stage the data, build SWave and train it.

    Args:
        args: parsed command-line arguments (see the parser at module level).
    """
    if args.run_distribute:
        print("distribute")
        # Default to device 0 so an unset DEVICE_ID does not crash int(None).
        device_id = int(os.getenv("DEVICE_ID", "0"))
        device_num = args.device_num
        context.set_context(device_id=device_id)
        init()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)
        rank_id = get_rank()          # id of this device within the cluster
        rank_size = get_group_size()  # total number of devices in the cluster
    else:
        device_id = args.device_id
        context.set_context(device_id=device_id)
        # Single-device run: behave like a one-member cluster so the
        # rank-based checkpoint path and dataset sharding below still work.
        # (Previously these names were undefined in this branch -> NameError.)
        rank_id = 0
        rank_size = 1

    print("--------------------MAIN----------------------------")

    # Resolve local working directories relative to this file.
    home = os.path.dirname(os.path.realpath(__file__))
    obs_data_url = args.data_url
    args.data_url = '/home/work/user-job-dir/inputs/data/'
    train_dir = os.path.join(home, 'checkpoints') + str(rank_id)  # per-rank checkpoint dir

    obs_train_url = args.train_url
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)

    # Copy the dataset from OBS into the training environment (best effort).
    try:
        mox.file.copy_parallel(obs_data_url, args.data_url)
        print("Successfully Download {} to {}".format(obs_data_url,
                                                      args.data_url))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(
            obs_data_url, args.data_url) + str(e))

    # SWave hyper-parameters, hard-coded to match the released configuration.
    kwargs = {'N': 128, 'L': 8, 'H': 128, 'R': 6, 'C': 2, 'input_normalize': False, 'sr': 8000, 'segment': 4}
    net = SWave(**kwargs)

    # Optionally warm-start from a pretrained checkpoint (disabled by default).
    flag = False
    if flag:
        ckpt = os.path.join(home, args.pretrained_sr)
        param_dict = load_checkpoint(ckpt)
        load_param_into_net(net, param_dict)

    print("开始prepro-------------")
    preprocess(args)

    # Training data, sharded across the cluster.
    tr_dataset = DatasetGenerator(args.train, args.data_batch_size, sample_rate=args.sample_rate, segment=args.segment)
    tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], shuffle=True, num_parallel_workers=8, num_shards=rank_size, shard_id=rank_id)
    tr_loader = tr_loader.batch(args.batch_size)
    print("结束prepro------------------------")

    data = {"tr_loader": tr_loader}

    # Loss wrapper around the network.
    loss_network = Generatorloss(net)

    optimizier = nn.Adam(net.trainable_params(), learning_rate=args.lr, beta1=0.9, beta2=args.beta2)

    # Forward + backward + optimizer update fused into a single cell.
    trainonestepNet = TrainOneStep(loss_network, optimizier, sens=1.0)

    train(trainonestepNet, data, train_dir, obs_train_url, args)
-
-
if __name__ == '__main__':
    args = parser.parse_args()
    # set_seed(42)
    # os.environ['GLOG_v'] = '1'
    # Run in graph mode on the backend chosen via --device_target
    # (Ascend/GPU/CPU); the device id is set later inside main().
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    print("---------------cont------------")
    main(args)
|