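"""Train and evaluate a Prototypical Network (ProtoNet) on Omniglot with MindSpore.

The options object returned by src.parser_util.get_parser is expected to expose
dataset_root, experiment_root, epochs, num_support_tr and num_query_tr.
"""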
import os

# Must be set before MindSpore is imported: GLOG_v level 3 keeps only ERROR-level framework logs.
os.environ['GLOG_v'] = '3'

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import dataset as ds
from mindspore.train import Model
from mindspore.train.callback import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint

from model_init import init_dataloader
from src.parser_util import get_parser
from src.protonet import ProtoNet
from src.PrototypicalLoss import PrototypicalLoss, prototypical_loss_or_acc
from src.metric import PrototyAccLoss
from src.EvalCallBack import EvalCallBack


def train(opt, tr_dataloader, net, loss_fn, optim, val_dataloader=None):
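    """Run the training loop with periodic evaluation and checkpointing."""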
    # Wrap the episodic generators so MindSpore can iterate them as (data, label) pairs.
    train_data = ds.GeneratorDataset(tr_dataloader, column_names=['data', 'label'])
    eval_data = ds.GeneratorDataset(val_dataloader, column_names=['data', 'label'])

    model = Model(net, loss_fn, optimizer=optim, metrics={"PrototyAccLoss": PrototyAccLoss()})
    eval_cb = EvalCallBack(model, eval_data, opt.experiment_root)

    config = CheckpointConfig(save_checkpoint_steps=10,
                              keep_checkpoint_max=5,
                              saved_network=net)
    ckpoint_cb = ModelCheckpoint(prefix='protonet', directory=opt.experiment_root, config=config)

    print('========== training ==========')
    model.train(opt.epochs, train_data, callbacks=[TimeMonitor(), eval_cb, ckpoint_cb, LossMonitor()])


def test(test_dataloader, net):
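    """Evaluate the network over 10 passes of the test set and report the mean accuracy."""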
    test_data = ds.GeneratorDataset(test_dataloader, column_names=['data', 'label'])
    avg_acc = []
    for _ in range(10):
        for batch in test_data.create_dict_iterator():
            x = batch['data']
            y = batch['label']
            output = net(x)
            # prototypical_loss_or_acc returns (loss, acc); only the accuracy is needed here.
            _, acc = prototypical_loss_or_acc(output, y, 5)
            print(acc)
            avg_acc.append(acc.asnumpy())
    avg_acc = np.mean(avg_acc)
    print('Test Acc: {}'.format(avg_acc))


def main():
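    """Parse options, build the dataloaders, and launch training."""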
    options = get_parser().parse_args()
    if not os.path.exists(options.experiment_root):
        os.makedirs(options.experiment_root)

    tr_dataloader = init_dataloader(options, 'train', options.dataset_root)
    val_dataloader = init_dataloader(options, 'val', options.dataset_root)

    loss_fn = PrototypicalLoss(options.num_support_tr, options.num_query_tr)
    net = ProtoNet()
    # Fixed learning rate; a schedule may leave room for improvement.
    optim = nn.Adam(params=net.trainable_params(), learning_rate=0.001)
    train(options, tr_dataloader, net, loss_fn, optim, val_dataloader)


if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(device_id=0)
    # context.set_context(device_target='CPU')  # uncomment to run on CPU
    main()