|
- from __future__ import division
- from mindspore.communication.management import init, get_rank, get_group_size
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
- from mindspore.context import ParallelMode
- import mindspore.nn as nn
- import mindspore.ops as ops
- from mindspore import Tensor
- from mindspore import context, Model
- import argparse as arg
- import moxing as mox
- import os, glob
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from FullConnectedNetwork import Fnet
- from GetDataset import get_train_data, get_val_data, get_test_data, BpskAddnoise
- import numpy as np
-
-
if __name__ == '__main__':
    # MindSpore SID training entry point: parse ModelArts-style arguments,
    # set up (data-parallel) Ascend context, pull the dataset from OBS,
    # train a fully-connected decoder, evaluate BER, and push results back.
    parser = arg.ArgumentParser(description='Mindspore SID Example')
    parser.add_argument('--device_target', default='Ascend',
                        help='device where the code will be implemented')
    parser.add_argument('--data_url', required=True, default=None, help='Location of data')
    parser.add_argument('--train_url', required=True, default=None, help='location of training outputs')
    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    # Distributed setup: the Ascend launcher exports DEVICE_ID / RANK_SIZE.
    # BUG FIX: int(None) raised TypeError when the vars were unset; default
    # to a single-card configuration instead.
    device_id = int(os.getenv('DEVICE_ID', '0'))
    device_num = int(os.getenv('RANK_SIZE', '1'))
    context.set_context(device_id=device_id)
    context.set_auto_parallel_context(device_num=device_num,
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True)
    init()

    # Remap the OBS bucket URLs to the job container's local paths, then
    # copy the dataset from OBS into the container (best effort: a failure
    # is logged, matching the original behavior).
    obs_data_url = args.data_url
    args.data_url = '/home/work/user-job-dir/inputs/data/'
    obs_train_url = args.train_url
    args.train_url = '/home/work/user-job-dir/outputs/model/'
    try:
        mox.file.copy_parallel(obs_data_url, args.data_url)
        print("Successfully Download {} to {}".format(obs_data_url,
                                                      args.data_url))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(
            obs_data_url, args.data_url) + str(e))

    # Hyper-parameters.
    N = 16            # code/block length fed to Fnet (presumed — confirm against Fnet)
    K = 8             # information length fed to Fnet
    epoch_num = 500
    batch_size = 256
    lr = 0.001

    net = Fnet(n=N, k=K)                                          # network instance
    loss = nn.MSELoss()                                           # loss function
    net_opt = nn.Adam(net.trainable_params(), learning_rate=lr)   # optimizer
    model = Model(net, loss_fn=loss, optimizer=net_opt)           # high-level wrapper

    # Periodic checkpointing plus loss/step-time monitoring during training.
    config_ck = CheckpointConfig(save_checkpoint_steps=256, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix='sony_trained_net', directory=args.train_url, config=config_ck)
    loss_cb = LossMonitor()
    time_cb = TimeMonitor()
    callbacks = [loss_cb, time_cb, ckpoint_cb]

    # Build the training pipeline: generate data, add BPSK noise, batch.
    # BUG FIX: the original referenced an undefined name `local_data_path`;
    # the dataset lives at the local copy of the OBS data, i.e. args.data_url.
    dataset = get_train_data(args.data_url)
    dataset = dataset.map(operations=[BpskAddnoise], input_columns=['data', 'label'],
                          output_columns=['data', 'label'])
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
    model.train(epoch=epoch_num, train_dataset=dataset, callbacks=callbacks, dataset_sink_mode=False)

    # Load the trained parameters into a fresh network and evaluate.
    # BUG FIX: removed the stray truncated statement `local_pa` (NameError).
    test_net = Fnet(n=N, k=K)
    ckpt_files = glob.glob(args.train_url + '*.ckpt')
    print(ckpt_files)
    ckpt_filepath = ckpt_files[-1]   # last match in glob order (presumed newest — confirm)
    param_dict = load_checkpoint(ckpt_filepath)
    load_param_into_net(test_net, param_dict)

    test_x, test_y = get_test_data(args.data_url)

    y_pred = test_net(Tensor(test_x)).asnumpy()
    y_true = test_y
    print('output shape : ', y_pred.shape, y_true.shape)
    # Bit error rate: fraction of hard-decision (rounded) bits differing
    # from the ground-truth labels.
    err_num = np.not_equal(y_true, np.round(y_pred))
    ber = np.mean(err_num)

    print('The test BER is : ', ber)

    # Upload the trained model artifacts back to the OBS bucket.
    mox.file.copy_parallel(src_url=args.train_url, dst_url=obs_train_url)
|