|
"""
Workflow:
1. Create the model
2. Prepare the data
3. Train
4. Run inference
"""
- import os
- import glob
- from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
- from mindspore.nn.optim import Adam
- from mindspore import Model
- from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
-
- # import moxing as mox
-
- from config import Configs
- from get_code_config import load_code_message
- from get_code_mat import get_mats
- from dataset import Dataset
- from layers import BPRNN
-
-
def _steps_per_epoch(num_samples, batch_size):
    """Return the number of full batches in one training epoch.

    The training dataset is batched with ``drop_remainder=True``, so only
    complete batches count toward an epoch.

    Raises:
        ValueError: if ``batch_size`` is not positive, or fewer than
            ``batch_size`` samples are available (the epoch would be empty).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size!r}")
    # Floor division is exact for large integers (unlike int(a / b)); the
    # int() keeps the result an int even if the config stores floats, since
    # LossMonitor / CheckpointConfig require integer step counts.
    steps = int(num_samples // batch_size)
    if steps < 1:
        raise ValueError(
            f"train_data_num ({num_samples!r}) is smaller than batch_size "
            f"({batch_size!r}); an epoch would contain no batches"
        )
    return steps


def _build_model(cs, H, n, m):
    """Build the BP-RNN network for ``H`` and wrap it in a MindSpore ``Model``.

    Args:
        cs: the ``Configs`` instance (``batch_size`` and ``iterations`` are read).
        H: code matrix returned by ``load_code_message``, passed to ``get_mats``.
        n, m: code dimensions returned by ``load_code_message``.
    """
    vc_mat, cv_mat, llr_mat, llr_mat_trans, num_edges = get_mats(H, m, n)
    # NOTE(review): m / n is passed as BPRNN's second argument. If H is an
    # m x n parity-check matrix, this is not the code rate (n - m) / n;
    # confirm which value BPRNN expects. Kept unchanged here.
    network = BPRNN(n, m / n, num_edges, cs.batch_size, llr_mat, llr_mat_trans,
                    vc_mat, cv_mat, cs.iterations)
    # NOTE(review): the loss uses its default arguments; confirm they match
    # the label format produced by Dataset.get_train_data.
    loss = SoftmaxCrossEntropyWithLogits()
    optimizer = Adam(network.trainable_params())
    return Model(network, loss, optimizer)


def _build_callbacks(steps, directory):
    """Return the training callbacks: timing, loss logging and checkpointing.

    Loss is reported, and a checkpoint with prefix ``bp-rnn`` is saved into
    ``directory``, every ``steps`` steps (once per epoch when ``steps`` is the
    per-epoch batch count). At most 100 checkpoints are kept.
    """
    config_ck = CheckpointConfig(save_checkpoint_steps=steps, keep_checkpoint_max=100)
    return [
        TimeMonitor(),
        LossMonitor(steps),
        ModelCheckpoint(prefix='bp-rnn', directory=directory, config=config_ck),
    ]


def train_net():
    """Build the BP-RNN decoder and train it with settings from ``Configs``.

    Loads the code description from ``Configs.H_file`` / ``Configs.G_file``
    and builds the network, loss, optimizer and ``Model``. Builds the training
    set via ``Dataset.get_train_data(phase='all_zero')``, batched with any
    incomplete final batch dropped. Trains for ``Configs.epochs`` epochs,
    writing checkpoints to ``Configs.result_path``.

    Raises:
        ValueError: if ``batch_size`` is not positive or ``train_data_num`` is
            smaller than ``batch_size`` (an epoch would contain no batches).
    """
    cs = Configs()
    # Validate before the potentially slow code loading and data generation.
    steps = _steps_per_epoch(cs.train_data_num, cs.batch_size)

    # Load the code description; k is not needed for training.
    H, G, n, m, _k = load_code_message(cs.H_file, cs.G_file)
    model = _build_model(cs, H, n, m)
    callbacks_list = _build_callbacks(steps, cs.result_path)

    data_loader = Dataset(cs.train_data_num, cs.test_data_num, cs.snr_l, cs.snr_h,
                          cs.snr_step, cs.iterations, G, n, m)
    dataset = data_loader.get_train_data(phase='all_zero')
    dataset = dataset.batch(batch_size=cs.batch_size, drop_remainder=True)

    model.train(cs.epochs, dataset, callbacks=callbacks_list)
-
-
if __name__ == '__main__':
    # Debug aid: list the files present under this hard-coded job directory
    # (presumably the remote training environment's code upload location).
    shipped_files = glob.glob('/home/work/user-job-dir/bp-rnn/codes/*')
    print(shipped_files)

    train_net()
|