|
- # -*- coding: utf-8 -*-
- """
- @author: huangxs
- @License: (C)Copyright 2021, huangxs
- @CreateTime: 2021/11/16 19:10:00
- @Filename: train
-
- """
import os

# Show only ERROR-level MindSpore logs (GLOG_v: 0=DEBUG ... 3=ERROR).
# This must run before anything imports mindspore. The project modules under
# src/ imported below presumably import mindspore themselves, so the variable
# is set first; set after them, it would be read too late to take effect.
os.environ['GLOG_v'] = "3"
# os.environ['DEVICE_ID'] = "7"  # uncomment to pin a specific device id

from src.han_net import HanNet
from src.utils.metrics_util import accuracy_pixel_level

from collections import OrderedDict

import mindspore.dataset as ds

from src.utils.dataset import MoNuSegGenerator, MoNuSegPreparedGenerator
from src.utils.direction_transform import get_transforms_list
from src.utils.loss import *  # presumably provides loss_ce / loss_dice used below

import glob
import numpy as np
import time

import mindspore
import mindspore.nn as nn
import mindspore.ops.functional as F
import mindspore.ops.operations as P
import mindspore.ops as ops
from mindspore import dtype as mstype
from mindspore import Tensor
from mindspore import save_checkpoint, load_checkpoint, load_param_into_net
from mindspore.common.initializer import One, Normal

from mindspore import context

# Eager (PyNative) execution on Ascend. HanNetWithLoss.construct pulls tensors
# to the host with .asnumpy(), which is presumably why graph mode is not used.
# Change device_target to "GPU" to train on a GPU instead.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
print('设置运行模式Ascend')  # "run mode set to Ascend"
-
-
def _labels_to_one_hot(label_maps, num_classes=3):
    """Encode a batch of label maps as one binary mask per class.

    The distinct values present anywhere in ``label_maps`` are sorted
    (``np.unique``). The k-th smallest value is assigned to channel k, for
    k < num_classes. Values beyond the first ``num_classes`` are ignored, and
    channels with no matching value stay all-zero.

    Args:
        label_maps: array indexed as ``label_maps[j, 0, h, w]``, i.e.
            (batch, 1, H, W); only channel 0 is encoded.
        num_classes: number of output channels.

    Returns:
        ``np.uint8`` array of shape (batch, num_classes, H, W).
    """
    label_maps = np.asarray(label_maps)
    one_hot = np.zeros(
        (label_maps.shape[0], num_classes, label_maps.shape[-2], label_maps.shape[-1]),
        dtype=np.uint8)
    values = np.unique(label_maps)
    if len(values) < num_classes:
        # Some classes are absent from this batch; their channels stay zero.
        print('train warning: batch has %d distinct label value(s), expected %d'
              % (len(values), num_classes))
    labels = label_maps[:, 0]
    for channel, value in enumerate(values[:num_classes]):
        # one_hot[:, channel] is a view, so the masked write lands in one_hot.
        one_hot[:, channel][labels == value] = 1
    return one_hot


class HanNetWithLoss(nn.Cell):
    """Training wrapper that runs HanNet and returns its loss.

    The loss is ``loss_ce`` (given the per-pixel weight map) plus
    ``loss_dice`` on a one-hot encoding of the target. Host-side pixel
    metrics are printed every 5th iteration for monitoring only.
    """

    def __init__(self, hannet, num_classes=3):
        """
        Args:
            hannet: the segmentation network to train.
            num_classes: number of channels in the one-hot Dice target.
        """
        super(HanNetWithLoss, self).__init__(auto_prefix=False)
        self._hannet = hannet
        self._num_classes = num_classes

    def construct(self, _data, epoch, i):
        """Run one forward pass and return the combined loss.

        Args:
            _data: batch dict with keys 'input', 'target0' and 'weight_map'.
            epoch: current epoch, used only for logging.
            i: iteration index within the epoch; metrics are computed and
                printed when ``i % 5 == 0``.

        Returns:
            Loss tensor ``loss_ce + loss_dice``.
        """
        _input = _data['input']
        _target0 = _data['target0']
        _weight_map = _data['weight_map']
        _output = self._hannet(_input)

        # Class-index target for cross-entropy. A 0..255-scaled label map is
        # mapped to small integers (255 // 127 == 2).
        target = _target0
        if target.max() == 255:
            target = target // int(255 / 2)
        if target.dim() == 4:
            target = target.squeeze(1)

        # One-hot target for the Dice loss, built on the host from raw labels.
        one_hot = _labels_to_one_hot(_target0.asnumpy(), self._num_classes)
        target_one_hot = Tensor(one_hot, dtype=mstype.float32)

        # The weight map is scaled down by 20 before being used.
        weight_map = _weight_map / 20
        if weight_map.dim() == 4:
            weight_map = weight_map.squeeze(1)

        _loss_ce, log_prob_maps = loss_ce(_output, target, weight_map)
        _loss_dice = loss_dice(_output, target_one_hot)
        loss = _loss_ce + _loss_dice

        if i % 5 == 0:
            # Metrics are logging-only, so only pay for the host copy here.
            pred = np.argmax(log_prob_maps.asnumpy(), axis=1)
            pixel_accu, pixel_iou, pixel_recall, pixel_precision, pixel_F1, _ = \
                accuracy_pixel_level(pred, target.asnumpy())
            print(
                'epoch:%3d, iter:%3d, loss=%.4f, l_ce=%.4f, l_d_dice=%.4f, pixel_accu=%.4f, p_iou=%.4f, p_recall=%.4f, p_precision=%.4f, p_F1=%.4f'
                % (epoch, i, float(loss.asnumpy()), float(_loss_ce.asnumpy()), float(_loss_dice.asnumpy()),
                   pixel_accu, pixel_iou, pixel_recall, pixel_precision, pixel_F1))

        return loss
-
-
def run_train(epoch_start=0, checkpoint_path='', epochs=501, batch_size=4,
              learning_rate=5e-4, data_dir='data_prepare/train',
              save_dir='checkpoint/train_save'):
    """Train HanNet on the prepared MoNuSeg data and save checkpoints.

    A checkpoint is written every 10 epochs. ``hannet_best.ckpt`` is
    overwritten whenever the mean epoch loss improves.

    Args:
        epoch_start: first epoch index (useful when resuming).
        checkpoint_path: optional .ckpt file to load weights from; empty or
            None trains from scratch.
        epochs: exclusive upper bound of the epoch range.
        batch_size: samples per batch.
        learning_rate: Adam learning rate.
        data_dir: directory passed to MoNuSegPreparedGenerator.
        save_dir: directory for checkpoints (created if missing).

    Returns:
        The loss tensor of the last training step, or 0 if no step ran.

    Raises:
        ValueError: if the dataset yields no batches.
    """
    print('start hannet train:')

    # ====== dataset ======
    train_data = MoNuSegPreparedGenerator(data_dir=data_dir)
    dataset = ds.GeneratorDataset(
        train_data, ['input', 'weight_map', 'target0', 'target_point0', 'target_direction0'])
    dataset = dataset.batch(batch_size)
    steps_per_epoch = dataset.get_dataset_size()
    if steps_per_epoch == 0:
        # Otherwise the per-epoch mean below would divide by zero.
        raise ValueError('no training batches found in %r' % data_dir)

    # ====== model ======
    print('modeling...')
    _hannet = HanNet(in_channels=3, output_channels=3)

    # Optionally resume from saved weights.
    if checkpoint_path:
        print('load checkpoint:', checkpoint_path)
        param_dict = load_checkpoint(checkpoint_path)
        load_param_into_net(_hannet, param_dict)

    _hannet_with_loss = HanNetWithLoss(_hannet)

    # Adam optimizer with weight decay.
    optim = nn.Adam(_hannet.trainable_params(), learning_rate=learning_rate,
                    beta1=0.9, beta2=0.999, weight_decay=5e-4)

    train_net = nn.TrainOneStepCell(_hannet_with_loss, optim)
    train_net.set_train()

    # Make sure the checkpoint directory exists before the first save.
    os.makedirs(save_dir, exist_ok=True)

    loss = 0
    # Start at +inf so the first finished epoch always yields a best checkpoint.
    best_epoch_loss = float('inf')
    for epoch in range(epoch_start, epochs):
        print('epoch: %03d' % epoch)
        epoch_loss = 0.0
        for i, data in enumerate(dataset.create_dict_iterator()):
            if i == 0:
                # NOTE(review): pause kept from the original; its purpose
                # (data-pipeline warm-up?) is not visible here -- confirm.
                time.sleep(5)
            loss = train_net(data, epoch, i)
            epoch_loss += float(loss.asnumpy())
        epoch_loss /= steps_per_epoch
        print('epoch loss:%.5f' % epoch_loss)

        if epoch % 10 == 0:
            checkpoint_name = os.path.join(
                save_dir, 'hannet_epoch_%03d_loss_%.2f.ckpt' % (epoch, epoch_loss))
            print('save checkpoint:', checkpoint_name)
            save_checkpoint(_hannet, checkpoint_name)
        if epoch_loss < best_epoch_loss:
            best_epoch_loss = epoch_loss
            checkpoint_name = os.path.join(save_dir, 'hannet_best.ckpt')
            print('save best checkpoint:', checkpoint_name)
            save_checkpoint(_hannet, checkpoint_name)
    return loss
-
-
if __name__ == "__main__":
    # Train from scratch. To resume, point checkpoint_path at a .ckpt file
    # and set epoch_start to the epoch to continue from.
    epoch_start, checkpoint_path = 0, ''
    run_train(epoch_start=epoch_start, checkpoint_path=checkpoint_path)
|