|
- """
Workflow:
1. Create the model
2. Prepare the data
3. Train
4. Inference
- """
- import os
- from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
- from mindspore.nn.optim import Adam
- from mindspore import Model
- from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
-
- # import moxing as mox
-
- from config import Configs
- from get_code_config import load_code_message
- from get_code_mat import get_mats
- from dataset import Dataset
- from layers import BPRNN
-
-
def _code_file_paths(project_root, code_file_name):
    """Return the ``(.alist, .gmat)`` file paths for ``code_file_name``.

    Both files are looked up directly under ``project_root``; the ``.alist``
    file is later loaded as ``H`` and the ``.gmat`` file as ``G``.
    """
    h_file = os.path.join(project_root, code_file_name + '.alist')
    g_file = os.path.join(project_root, code_file_name + '.gmat')
    return h_file, g_file


def _build_callbacks(steps_per_epoch, save_path):
    """Create the timing, loss-logging and checkpoint callbacks for ``Model.train``.

    The loss is printed, and a checkpoint (prefix ``bp-rnn``) is written to
    ``save_path``, every ``steps_per_epoch`` steps; at most 100 checkpoints
    are kept.
    """
    time_cb = TimeMonitor()
    loss_cb = LossMonitor(steps_per_epoch)
    config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=100)
    ckpoint_cb = ModelCheckpoint(prefix='bp-rnn', directory=save_path, config=config_ck)
    return [time_cb, loss_cb, ckpoint_cb]


def train_net():
    """Build the BP-RNN decoder for the configured code and train it.

    All settings come from ``Configs()``. The code's ``.alist`` / ``.gmat``
    files are read from the project root. The network is trained with
    ``SoftmaxCrossEntropyWithLogits`` and ``Adam`` for ``cs.epochs`` epochs,
    and checkpoints are written to ``<project_root>/<cs.result_path>``.

    Raises:
        ValueError: if ``cs.train_data_num`` is smaller than ``cs.batch_size``.
            The per-epoch step count would then be 0, which would be used as
            the loss-print and checkpoint interval, and ``drop_remainder=True``
            batching would discard the only (partial) batch.
    """
    # Resolve paths for ModelArts. This file's directory is the boot-file
    # directory, and its parent is the project root (the code directory set
    # in the ModelArts training console).
    current_path = os.path.dirname(os.path.realpath(__file__))
    project_root = os.path.dirname(current_path)

    cs = Configs()

    # Steps per epoch. This assumes the training set holds cs.train_data_num
    # samples (TODO confirm against Dataset.get_train_data). Validate it before
    # loading any files so a bad config fails fast.
    steps = int(cs.train_data_num / cs.batch_size)
    if steps < 1:
        raise ValueError(
            f'train_data_num ({cs.train_data_num}) must be at least '
            f'batch_size ({cs.batch_size})')

    # Load the code description and build the message-passing matrices.
    h_file, g_file = _code_file_paths(project_root, cs.code_file_name)
    H, G, n, m, _ = load_code_message(h_file, g_file)
    vc_mat, cv_mat, llr_mat, llr_mat_trans, num_edges = get_mats(H, m, n)

    # Build the network, loss, optimizer and model wrapper.
    # NOTE(review): m / n is passed as a float ratio; confirm this is what
    # BPRNN expects in this position (rather than, e.g., the rate k / n).
    network = BPRNN(n, m / n, num_edges, cs.batch_size, llr_mat, llr_mat_trans,
                    vc_mat, cv_mat, cs.iterations)
    loss = SoftmaxCrossEntropyWithLogits()
    optimizer = Adam(network.trainable_params())
    model = Model(network, loss, optimizer)

    save_path = os.path.join(project_root, cs.result_path)
    callbacks_list = _build_callbacks(steps, save_path)

    # Load the training data.
    data_loader = Dataset(cs.train_data_num, cs.test_data_num, cs.snr_l, cs.snr_h,
                          cs.snr_step, cs.iterations, G, n, m)
    dataset = data_loader.get_train_data(phase='all_zero')
    # drop_remainder=True keeps every batch at exactly cs.batch_size. This is
    # presumably required because BPRNN is constructed with that fixed batch
    # size; confirm.
    dataset = dataset.batch(batch_size=cs.batch_size, drop_remainder=True)

    model.train(cs.epochs, dataset, callbacks=callbacks_list)
-
-
if __name__ == '__main__':
    # Run training only when this file is executed directly as a script.
    train_net()
|