|
- #!/usr/bin/env python
-
- # Created on 2018/12
- # Author: Kaituo XU
-
- import hydra
- from svoice.models.swave import SWave
- from mindspore import Model
- from svoice.data.data_test2_5_5 import DatasetGenerator
- import mindspore.dataset as ds
- from mindspore import nn
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
- from svoice.network_define import WithLossCell
- from svoice.models.Loss_final1 import myloss
-
def _build_net(args):
    """Instantiate the separation network selected by ``args.model``.

    Args:
        args: Hydra config; reads ``model``, ``swave``, ``sample_rate`` and ``segment``.

    Returns:
        The constructed ``SWave`` network.

    Raises:
        ValueError: if ``args.model`` is anything other than ``"swave"``.
    """
    if args.model != "swave":
        # Without this guard an unknown model left ``net`` unbound and training
        # crashed later with an unhelpful NameError; fail early instead.
        raise ValueError(
            f"Unsupported model {args.model!r}; only 'swave' is implemented in this script."
        )
    kwargs = dict(args.swave)
    kwargs['sr'] = args.sample_rate
    kwargs['segment'] = args.segment
    return SWave(**kwargs)


def _build_train_loader(args):
    """Wrap the training ``DatasetGenerator`` in a batched MindSpore ``GeneratorDataset``.

    Args:
        args: Hydra config; reads ``dset.train``, ``data_batch_size``, ``sample_rate``,
            ``segment`` and ``batch_size``.
    """
    tr_dataset = DatasetGenerator(args.dset.train, args.data_batch_size,
                                  sample_rate=args.sample_rate, segment=args.segment)
    # Column names presumably mirror the per-item tuple DatasetGenerator yields
    # (mixture, lens, sources) -- verify against svoice.data.data_test2_5_5.
    # NOTE(review): shuffle=False for a training set -- confirm this is intentional.
    tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], shuffle=False)
    return tr_loader.batch(args.batch_size)


def _build_callbacks(args):
    """Return training callbacks: timing, loss logging and, if enabled, checkpointing.

    Args:
        args: Hydra config; reads ``checkpoint`` and, when it is truthy,
            ``keep_checkpoint_max`` and ``dset.outputs``.
    """
    callbacks = [TimeMonitor(), LossMonitor()]
    if args.checkpoint:
        # Save every 5 steps, retaining at most ``keep_checkpoint_max`` files.
        config_ck = CheckpointConfig(save_checkpoint_steps=5,
                                     keep_checkpoint_max=args.keep_checkpoint_max)
        # NOTE(review): the "gdprnn" prefix does not match the SWave model name;
        # kept as-is so existing checkpoint file names do not change.
        callbacks.append(ModelCheckpoint(prefix="gdprnn", directory=args.dset.outputs,
                                         config=config_ck))
    return callbacks


@hydra.main(config_path="conf", config_name='config.yaml')
def main(args):
    """Train the configured separation model with Adam and the project loss.

    Builds the network, a batched training loader and the callbacks from the
    Hydra config ``args``. It then trains for 10 epochs in non-sink mode.

    Args:
        args: Hydra config loaded from ``conf/config.yaml``.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    net = _build_net(args).set_train()
    tr_loader = _build_train_loader(args)
    optimizer = nn.Adam(net.trainable_params(), learning_rate=args.lr,
                        beta1=0.9, beta2=args.beta2)
    net_with_loss = WithLossCell(net, myloss())
    model = Model(net_with_loss, optimizer=optimizer)
    # NOTE(review): the epoch count is hard-coded; consider reading it from the config.
    model.train(epoch=10, train_dataset=tr_loader, callbacks=_build_callbacks(args),
                dataset_sink_mode=False)
-
- # @hydra.main(config_path="conf", config_name='config.yaml')
- # def main(args):
- # run(args)
- # try:
- # run(args)
- # except Exception:
- # print("Some error happened")
- # os._exit(1)
-
if __name__ == '__main__':
    import os

    from mindspore import context

    # Ascend device to run on. The DEVICE_ID environment variable, when set,
    # overrides the previously hard-coded default of 2.
    # (For debugging, pass save_graphs=True to set_context to dump compiled graphs.)
    device_id = int(os.getenv('DEVICE_ID', '2'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
    main()
|