|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- import numpy as np
- import mindspore as ms
- from mindspore import context, nn, ops, load_checkpoint, load_param_into_net
- from mindspore.train.model import Model
- from mindspore.train.callback import TimeMonitor
-
- from src.config.default import get_config
- from src.dataset.ShapeNet import create_shapenet_dataset
- from src.utils.metric import IoU, WithEvalCell
- from src.utils.common import context_device_init
- from src.utils.local_adapter import get_device_id
- from src.model.pointTransformerSeg import get_mode_seg
-
class CustomWithLossCell(nn.Cell):
    """Pair a segmentation backbone with an NLL-style loss function.

    The backbone output is turned into log-probabilities with ``LogSoftmax``,
    flattened to ``(-1, num_classes)``, and scored against the flattened
    labels using an all-ones per-class weight.

    Args:
        backbone (nn.Cell): network whose output is unpacked as a rank-3
            tensor with the class axis last (presumably
            ``(batch, points, classes)`` -- TODO confirm against the model).
        loss_fn: callable invoked as ``loss_fn(pred, labels, weight)`` that
            returns a pair whose first element is the loss, e.g.
            ``ops.NLLLoss()``.
    """

    def __init__(self, backbone, loss_fn):
        # auto_prefix=False: do not prefix the backbone's parameter names
        # with this wrapper's attribute name.
        super().__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn
        self._log_softmax = nn.LogSoftmax()

    def construct(self, data, _, label2):
        """Compute the loss for one batch.

        Args:
            data: input forwarded unchanged to the backbone.
            _: ignored input column (the cell still accepts three inputs).
            label2: target labels; flattened to 1-D before scoring.

        Returns:
            The first element of ``loss_fn``'s result (the loss value).
        """
        log_probs = self._log_softmax(self._backbone(data))
        # Unpacking (rather than shape[-1]) keeps the rank-3 requirement.
        _, _, num_classes = log_probs.shape
        flat_log_probs = log_probs.view(-1, num_classes)
        flat_labels = label2.view(-1)
        # Uniform weighting: every class contributes equally to the loss.
        class_weight = ms.Tensor(np.ones(num_classes), ms.float32)
        loss, _ = self._loss_fn(flat_log_probs, flat_labels, class_weight)
        return loss
-
def test(cfg):
    """Evaluate a pretrained segmentation checkpoint on the ShapeNet test split.

    The function does the following, in order:

    1. Resolves the device id.
    2. Initialises the MindSpore context in graph mode.
    3. Builds the ``'test'`` dataset.
    4. Loads ``cfg.pretrain_ckpt`` into the network.
    5. Runs ``Model.eval`` with an ``IoU`` metric.
    6. Prints the instance and category mean IoU.

    Args:
        cfg: configuration object from ``get_config``. This function reads
            and may overwrite ``cfg.device_id`` and reads
            ``cfg.pretrain_ckpt``. Other fields are consumed by the helpers
            it calls.
    """
    # NOTE(review): `not cfg.device_id` is also true for an explicit 0, so a
    # configured device 0 is replaced by get_device_id(); confirm intended.
    if not cfg.device_id:
        cfg.device_id = get_device_id()
    context_device_init(cfg, context.GRAPH_MODE)

    dataset = create_shapenet_dataset('test', cfg)

    network = get_mode_seg()
    param_dict = load_checkpoint(cfg.pretrain_ckpt)
    load_param_into_net(network, param_dict, strict_load=True)
    network.set_train(False)

    # Build the pieces in the same order as before (metric, eval cell, loss
    # cell): wrapping cells can touch parameter naming, so order is kept.
    metrics = {'IoU': IoU()}
    eval_cell = WithEvalCell(network, True)
    loss_cell = CustomWithLossCell(network, ops.NLLLoss())
    model = Model(loss_cell, eval_network=eval_cell, metrics=metrics)

    timer = TimeMonitor(dataset.get_dataset_size())
    result = model.eval(dataset, dataset_sink_mode=False, callbacks=timer)

    # Index 0 is reported as the category mIoU, index 1 as the instance mIoU.
    iou = result['IoU']
    cat_miou = iou[0]
    ins_miou = iou[1]
    print(f"ins. mIoU is {ins_miou}, cat. mIoU is {cat_miou}")
-
if __name__ == '__main__':
    # Script entry point: load the configuration and run evaluation.
    config = get_config()
    test(config)
|