|
- import os
- import torch
- import argparse
- from dataset_tools.dataset_newer_1012 import TR_label_SlicePairsDataset
- from torch.utils.data import DataLoader
- from CSOMP_modeling_1024 import bpsk_demod_model
-
-
- import torch.optim as optim
- import torch.nn as nn
- from torch.optim import lr_scheduler
- import torch.utils.data as D
-
# Mode flag for "simple pairs" training; not read anywhere in this file —
# presumably consumed by another module or left over. TODO(review): confirm.
simple_pairs_training = True
-
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
-
def get_args():
    """Parse command-line options: model dimensions, optimiser settings and
    the checkpoint path used by do_test()."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dim', type=int, default=2560)
    parser.add_argument('--output_dim', type=int, default=20)

    # 0.0015 for Adam, 0.015 for SGD
    parser.add_argument('--lr', type=float, default=0.015)
    parser.add_argument('--weight_decay', type=float, default=0.01)

    parser.add_argument('--load_path', type=str,
                        default="/userhome/wave_training_old/"
                                "bpsk_demod_MLP_SGD/epoch99_iter299.pth")

    return parser.parse_args()
-
def do_valid(model, valid_dataloader, is_valid=True):
    """Switch *model* to eval mode, measure BER on *valid_dataloader* and
    print it, tagged [Valid] when is_valid is True and [Test] otherwise.

    NOTE: leaves the model in eval mode; the training loop re-enables
    train mode at the start of each iteration.
    """
    model = model.eval()
    measured_ber = model_test_func(model, valid_dataloader)
    if is_valid:
        print('> [Valid] BER is {}%.'.format(measured_ber))
    else:
        print('> [Test] BER is {}%.'.format(measured_ber))
-
-
def do_test(args, raw_train_data_folder='./raw_data/train', raw_test_data_folder='./raw_data/test'):
    """Load a pretrained checkpoint from args.load_path and report BER on the
    test split.

    The data-folder arguments default to this script's standard layout.
    Previously the function read them from module globals that only exist when
    the file runs as a script (they are bound inside the __main__ guard), so
    calling do_test() from an importing module raised NameError.
    """
    model, train_dataloader, valid_dataloader, test_dataloader = \
        setup_model_with_datasets_simplePairs(raw_train_data_folder, raw_test_data_folder, args)
    # map_location='cpu' makes the checkpoint loadable regardless of which
    # device it was saved from; the model is moved to GPU explicitly below.
    state_dict = torch.load(args.load_path, map_location='cpu')
    model.load_state_dict(state_dict)
    print("> loading pretrained_model from {} passed! >>>".format(args.load_path))
    model.cuda()
    model.eval()

    test_ber = model_test_func(model, test_dataloader)
    print('> [Test] BER is {}%.'.format(test_ber))
-
-
def model_test_func(model, test_dataloader):
    """Measure the bit error rate (BER, in %) of *model* over at most 100
    batches of *test_dataloader*.

    Each waveform slice is pushed through the model, the outputs are
    thresholded at 0.5 into hard bit decisions, and the fraction of bits that
    DISAGREE with the labels is averaged over the evaluated batches.

    Returns the mean per-batch BER (a 0-dim float tensor), or 0.0 for an
    empty dataloader.
    """
    # Follow the model's device instead of unconditionally calling .cuda(),
    # so evaluation also works on CPU-only machines; when the model lives on
    # GPU (as in train()/do_test()) behavior is unchanged.
    params = list(model.parameters())
    device = params[0].device if params else torch.device('cpu')

    test_ber = 0.0
    batch_num = 0
    with torch.no_grad():
        for i, (wav_slice, label01_slice) in enumerate(test_dataloader):
            if i == 100:  # cap evaluation cost at 100 batches
                break
            wav_slice = wav_slice.float().to(device)
            label01_slice = label01_slice.to(device)

            output = model(wav_slice)
            prediction = (output >= 0.5).long()
            total_nums = label01_slice.size(0) * label01_slice.size(1)
            # BUG FIX: BER counts *mismatched* bits. The original summed the
            # MATCHING bits (prediction == label), which reports accuracy
            # while every caller prints it as "BER".
            errors = (prediction != label01_slice)
            this_ber = torch.sum(errors) / float(total_nums) * 100.0

            test_ber += this_ber
            batch_num += 1

    if batch_num == 0:  # empty loader: avoid ZeroDivisionError
        return 0.0
    return test_ber / batch_num
-
def setup_model_with_datasets_simplePairs(raw_train_data_folder, raw_test_data_folder, args):
    """Build the slice-pair dataset, split it 90/7/3 into train/valid/test,
    wrap each split in a batch-size-1 DataLoader, and construct the
    demodulation model on GPU.

    NOTE(review): the two folder arguments are not used here — the dataset
    class appears to locate its own data; kept for interface compatibility.
    """
    dataset = TR_label_SlicePairsDataset()

    # Integer split: one "unit" is len(dataset) // 100; leftovers go to test.
    train_scale, valid_scale, test_scale = 90, 7, 3
    unit = len(dataset) // (train_scale + valid_scale + test_scale)
    train_nums = int(unit * train_scale)
    valid_nums = int(unit * valid_scale)
    test_nums = len(dataset) - train_nums - valid_nums

    train_dataset, valid_dataset, test_dataset = D.random_split(
        dataset, [train_nums, valid_nums, test_nums])

    # num_workers / drop_last were previously considered for the train loader.
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    my_model = bpsk_demod_model(args.input_dim, args.output_dim)
    my_model.cuda()

    return my_model, train_dataloader, valid_dataloader, test_dataloader
-
def train(raw_train_data_folder, raw_test_data_folder, args):
    """Train the BPSK demodulation MLP with SGD + cosine LR annealing.

    Runs 600 epochs over the train split.  Every 100 iterations the running
    epoch loss and current LR are printed; every 150 iterations on every 10th
    epoch, validation and test BER are printed and a checkpoint is saved
    under *saving_dir*.
    """
    my_model, train_dataloader, valid_dataloader, test_dataloader = \
        setup_model_with_datasets_simplePairs(raw_train_data_folder, raw_test_data_folder, args)

    print(f'The model has {count_parameters(my_model):,} trainable parameters')
    # weight_decay / betas / eps belonged to an earlier Adam setup; plain SGD here.
    optimizer = optim.SGD(my_model.parameters(), lr=args.lr)
    # eta_min: 1e-4 for Adam; 2e-3 for SGD.
    # NOTE(review): the scheduler is stepped once per *batch* (see loop below),
    # so T_max=200 means one cosine cycle every 200 iterations, not 200 epochs
    # — confirm this is intended.
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=200, eta_min=2e-3, verbose=False)

    criterion = nn.BCELoss()
    saving_dir = './bpsk_demod_MLP_SGD_SimplePairs_sorted_sample200'
    # makedirs(exist_ok=True) is race-free and handles nested paths; the old
    # exists()+mkdir pair could crash if the directory appeared in between.
    os.makedirs(saving_dir, exist_ok=True)

    for epoch in range(600):
        epoch_loss = 0
        for (i, data) in enumerate(train_dataloader):
            # do_valid() leaves the model in eval mode; re-enable every step.
            my_model.train()
            optimizer.zero_grad()

            this_TR_realwav_slice, next_label01_slice = data
            this_TR_realwav_slice = this_TR_realwav_slice.float().cuda()
            next_label01_slice = next_label01_slice.cuda()

            output = my_model(this_TR_realwav_slice)
            # model emits per-bit probabilities; BCE against the 0/1 labels
            loss = criterion(output, next_label01_slice.float())
            loss.backward()

            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()

            if (i+1) % 100 == 0:
                print('epoch[{}], iter[{}]/[{}],'
                      ' loss is: {}, lr is: {}'.format(epoch,
                                                       i,
                                                       len(train_dataloader),
                                                       epoch_loss / (i+1),
                                                       optimizer.state_dict()['param_groups'][0]['lr']
                                                       ))
            if (i+1) % 150 == 0 and (epoch % 10 == 0):
                do_valid(my_model, valid_dataloader, is_valid=True)
                do_valid(my_model, test_dataloader, is_valid=False)
                torch.save(my_model.state_dict(), './{}/epoch{}_iter{}.pth'.format(saving_dir, epoch, i))
-
-
-
if __name__ == '__main__':
    # NOTE(review): the two *_wav_* directories below are never read in this
    # file — possibly consumed by the dataset module; confirm before removing.
    train_wav_dataset_dir = './raw_data/train/data'
    train_wav_labels_dir = './raw_data/train/labels'
    raw_train_data_folder = './raw_data/train'
    raw_test_data_folder = './raw_data/test'

    args = get_args()

    # Train by default; swap the comments to evaluate a saved checkpoint instead.
    train(raw_train_data_folder, raw_test_data_folder, args)
    # do_test(args)
-
-
-
|