|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- import os
-
- import numpy as np
- import mindspore
- from mindspore import context, Tensor, nn, ops, load_checkpoint, load_param_into_net, DynamicLossScaleManager
- from mindspore.train.model import Model
- from mindspore.train.callback import TimeMonitor
-
- from src.dataset.ShapeNet import create_shapenet_dataset
- from src.model.pointTransformerSeg import get_mode_seg
- from src.config.default import get_config
- from src.utils.common import context_device_init
- from src.utils.callback import CallbackSaveByIoU, CheckLoss
- from src.utils.local_adapter import get_device_id, moxing_wrapper
- from src.utils.metric import IoU, WithEvalCell
- from src.utils.lr_scheduler import MultiStepLR
-
-
-
class CustomTrainOneStepCell(nn.Cell):
    """One training step: forward pass, gradient computation, optimizer update.

    Args:
        network: forward network whose output is the scalar training loss.
        optimizer: optimizer applied to ``network``'s trainable parameters.
    """

    def __init__(self, network, optimizer):
        super(CustomTrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        # Enable gradient-graph construction on the wrapped network.
        self.network.set_grad()
        self.optimizer = optimizer
        self.weights = self.optimizer.parameters
        self.grad = ops.GradOperation(get_by_list=True)

    def construct(self, *inputs):
        """Run one step and return the loss computed before the update."""
        loss = self.network(*inputs)
        grad_fn = self.grad(self.network, self.weights)
        grads = grad_fn(*inputs)
        self.optimizer(grads)
        return loss
-
-
class CustomWithLossCell(nn.Cell):
    """Wrap the backbone with an NLL loss over log-softmax outputs.

    The dataset yields ``(data, label1, label2)``; only ``data`` and the
    per-point ``label2`` are consumed here — the middle element is ignored.

    Args:
        backbone: segmentation network producing (B, N, C) class scores.
            (Shape assumed from the ``_, _, C = pred.shape`` unpacking —
            confirm against the dataset/model.)
        loss_fn: expected to be ``ops.NLLLoss()``; returns ``(loss, weight)``.
    """

    def __init__(self, backbone, loss_fn):
        super(CustomWithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn
        self._log_softmax = nn.LogSoftmax()
        # Instantiate the primitive once here instead of inside construct():
        # creating ops per call is wasteful and fragile in graph mode.
        self._ones = ops.Ones()

    def construct(self, data, _, label2):
        output = self._backbone(data)
        pred = self._log_softmax(output)
        _, _, C = pred.shape
        # Flatten (B, N, C) predictions and matching labels for NLLLoss.
        pred = pred.view(-1, C)
        label2 = label2.view(-1, 1)[:, 0]
        # Uniform class weights of shape (C,).  The original passed the bare
        # int ``(C)``; an explicit one-element tuple is equivalent but clear.
        weight = self._ones((C,)).astype(mindspore.float32)
        loss, weight = self._loss_fn(pred, label2, weight)
        return loss
-
-
@moxing_wrapper()
def train(cfg):
    """Train the point-transformer segmentation network on ShapeNet.

    Args:
        cfg: configuration object (see ``src.config.default.get_config``)
            providing dataset settings, ``epoch_size``, ``learning_rate``,
            ``weight_decay``, ``save_checkpoint_path``, ``run_eval``,
            ``eval_proid`` and optionally ``pretrain_ckpt``.
    """
    cfg.device_id = get_device_id()
    context_device_init(cfg, context.GRAPH_MODE)

    print('Load dataset ...')
    traindataset = create_shapenet_dataset('train', cfg)
    step_size = traindataset.get_dataset_size()

    net = get_mode_seg()
    max_epoch = cfg.epoch_size

    if cfg.pretrain_ckpt:
        print('loading pretrain model')
        checkpoint = load_checkpoint(cfg.pretrain_ckpt)
        load_param_into_net(net, checkpoint)
        print('Use pretrain model')
    else:
        print('No existing model, starting training from scratch...')

    # Step-wise schedule: multiply the LR by 0.1 at epochs 60, 120 and 160.
    lr = MultiStepLR(cfg.learning_rate,
                     [60, 120, 160],
                     0.1,
                     step_size,
                     cfg.epoch_size).get_lr()

    opt = nn.SGD(params=net.trainable_params(),
                 learning_rate=lr,
                 momentum=0.9,
                 weight_decay=cfg.weight_decay)

    # Defaults for the no-evaluation path.  BUG FIX: eval_proid was previously
    # assigned only inside the run_eval branch, so cfg.run_eval == False made
    # the CallbackSaveByIoU construction below raise NameError.
    eval_dataset = None
    cus_metrics = None
    cus_eval_network = None
    eval_proid = None

    save_checkpoint_path = cfg.save_checkpoint_path
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_checkpoint_path, exist_ok=True)

    if cfg.run_eval:
        cus_metrics = {'IoU': IoU()}
        cus_eval_network = WithEvalCell(net, True)
        eval_dataset = create_shapenet_dataset('test', cfg)
        eval_proid = cfg.eval_proid

    net_with_criterion = CustomWithLossCell(net, ops.NLLLoss())

    # Dynamic loss scaling for O2 mixed-precision training.
    scale_factor = 4
    scale_window = 3000
    loss_scale_manager = DynamicLossScaleManager(scale_factor, scale_window)

    model = Model(net_with_criterion,
                  optimizer=opt,
                  amp_level="O2",
                  eval_network=cus_eval_network,
                  metrics=cus_metrics,
                  loss_scale_manager=loss_scale_manager)

    # Checkpointing by best IoU, NaN/Inf loss monitoring, per-step timing.
    ckpoint_cb = CallbackSaveByIoU(model, eval_dataset, eval_proid, save_checkpoint_path)
    loss_cb = CheckLoss()
    time_cb = TimeMonitor(step_size)

    print("============== Starting Training ==============")
    model.train(max_epoch, traindataset, callbacks=[time_cb, loss_cb, ckpoint_cb])
-
-
if __name__ == '__main__':
    # Script entry point: build the config from defaults/CLI and start training.
    train(get_config())
|