|
- import os
- # os.environ['DEVICE_ID'] = "2"
- import time
- import random
- import logging
- import numpy as np
- from mindspore import save_checkpoint
- from models.MultiTaskNet import MultiTaskUNet
- from models.utils import mmseg_acc, accuracy_pixel_level
- from data_folder import DataFolder
- from my_transforms import get_transforms
- from tensorboardX import SummaryWriter
- # from sklearn.metrics import accuracy_score
- import mindspore.nn as nn
- import shutil
- import argparse
- from data_folder import create_dataset
- from models.loss import MultiLoss
- from models.crossentropy import CrossEntropy
- from mindspore import context
-
- # context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-
# Module-level TensorBoard writer (tensorboardX); main() logs the training
# loss and the validation metrics to it via writer.add_scalar(...).
writer = SummaryWriter()
-
-
def mysave_checkpoint(net, epoch, is_best, save_dir):
    """Save the network weights for this epoch and mirror the best checkpoint.

    Args:
        net: network whose parameters are serialized via mindspore.save_checkpoint.
        epoch (int): 0-based epoch index embedded in the checkpoint file name.
        is_best (bool): when True, also copy this checkpoint to final_best.ckpt.
        save_dir: currently unused — files always go to ./checkpoints/.
            NOTE(review): callers pass args.model_save_dir here; consider
            honoring it, but that changes where checkpoints land — confirm first.
    """
    cp_dir = 'checkpoints'
    # exist_ok avoids the crash os.mkdir raised when the directory already
    # existed (e.g. on any run after the first).
    os.makedirs(cp_dir, exist_ok=True)
    # Build the path once and reuse it (original duplicated the string).
    filename = os.path.join(cp_dir, 'epoch_' + str(epoch) + '.ckpt')
    save_checkpoint(net, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(cp_dir, 'final_best.ckpt'))
-
-
def main(args, logger):
    """Train MultiTaskUNet on joint segmentation + classification and evaluate.

    Builds the train/validation datasets, optimizer and multi-task loss, runs
    the training loop, logs metrics to *logger* and the module-level
    TensorBoard writer, and checkpoints via mysave_checkpoint().

    Args:
        args: parsed CLI namespace (batch_size, num_classes, seg_loss,
            optimizer, lr, weight_decay, num_epoches, train_img_dir,
            train_label_dir, eval_per_epoch, model_save_dir).
        logger: configured logging.Logger for progress messages.

    Raises:
        ValueError: if args.optimizer or args.seg_loss is unrecognized.
    """
    best_acc = 0
    dsets = {}
    # Augmentation only for training; validation just rescales + tensorizes.
    data_transforms = {
        'train': get_transforms({
            'scale': 240,
            'horizontal_flip': True,
            'random_rotation': 90,
            'random_crop': 240,
            'to_tensor': 1
        }),
        'validation': get_transforms({
            'scale': 240,
            'to_tensor': 1
        })
    }

    # Each split is described by [image_dir, label_dir].
    for x in ['train', 'validation']:
        img_dir = os.path.join(args.train_img_dir, x)
        target_dir = os.path.join(args.train_label_dir, x)
        dsets[x] = [img_dir, target_dir]
    train_loader = create_dataset(dir_list=dsets["train"], post_fix=['.png'], num_channels=[3, 1],
                                  data_transforms=data_transforms["train"],
                                  column_names=["input", "target", "category"],
                                  batch_size=args.batch_size, shuffle=True)
    val_loader = create_dataset(dir_list=dsets["validation"], post_fix=['.png'], num_channels=[3, 1],
                                data_transforms=data_transforms["validation"],
                                column_names=["input", "target", "category"],
                                batch_size=args.batch_size, shuffle=False)

    net = MultiTaskUNet(n_channels=3, n_classes=args.num_classes)

    # ----- define optimizer ----- #
    # Compare case-insensitively: argparse offers choices "SGD"/"Adam", but the
    # original code tested for the lowercase 'adam', so selecting "Adam" left
    # `optimizer` undefined and crashed with a NameError below.
    opt_name = args.optimizer.lower()
    if opt_name == 'adam':
        optimizer = nn.Adam(params=net.trainable_params(), learning_rate=args.lr, beta1=0.9, beta2=0.99,
                            weight_decay=args.weight_decay)
    elif opt_name == 'sgd':
        optimizer = nn.SGD(params=net.trainable_params(), learning_rate=args.lr, momentum=0.9,
                           weight_decay=args.weight_decay)
    else:
        raise ValueError('Unsupported optimizer: {}'.format(args.optimizer))

    # ----- define losses ----- #
    if args.seg_loss == "CE":
        criterion_seg = CrossEntropy(num_classes=8)
    elif args.seg_loss == "MSE":
        criterion_seg = nn.MSELoss()
    else:
        raise ValueError('Unsupported seg_loss: {}'.format(args.seg_loss))
    criterion_cls = CrossEntropy(num_classes=3)
    loss_fn = MultiLoss(seg_loss=criterion_seg, cls_loss=criterion_cls)
    loss_net = nn.WithLossCell(net, loss_fn)
    train_net = nn.TrainOneStepCell(loss_net, optimizer)
    train_net.set_train()

    best_cri = -1
    for epoch in range(args.num_epoches):
        # Train for one epoch (len(train_loader) iterations).
        logger.info('Epoch: [{:d}/{:d}]'.format(epoch + 1, args.num_epoches))
        for index, sample in enumerate(train_loader):
            image, target, category = sample  # renamed from `input` (shadowed builtin)
            target = target.squeeze()
            category = category.squeeze()
            label = (target, category)
            loss = train_net(image, label)
            loss = float(loss.asnumpy())
            # Extra forward pass used only to compute training metrics.
            segoutput, clsoutput = net(image)
            # Segmentation metrics (pixel accuracy / mIoU), background ignored.
            pred = np.argmax(segoutput.asnumpy(), axis=1)
            all_acc, acc, iou = mmseg_acc(pred, target.asnumpy(), num_classes=args.num_classes, ignore_index=0)
            metrics = accuracy_pixel_level(pred, target.asnumpy())
            # Classification prediction (precision computation currently disabled).
            pred_cls = np.argmax(clsoutput.asnumpy(), axis=1)
            if index % 30 == 0:
                logger.info(
                    "Epochs: {}/{},Training Loss : {:.2f}, Pixel accu : {:.2f}, mIoU : {:.2f}"
                    .format(epoch + 1,
                            args.num_epoches,
                            loss,
                            all_acc,
                            metrics[1]))
            niter = epoch * train_loader.get_dataset_size() + index
            writer.add_scalar('Train_Total_Loss', loss, niter)

        if (epoch + 1) % args.eval_per_epoch == 0:
            eval_results = np.zeros((2,), np.float32)
            num_batches = 0
            t = 0  # total validation samples seen
            p = 0  # correct classifications (stays 0: accuracy_score disabled)
            for index, sample in enumerate(val_loader):
                image, target, category = sample
                target = target.squeeze()
                segoutput, clsoutput = net(image)
                pred = np.argmax(segoutput.asnumpy(), axis=1)
                all_acc, acc, iou = mmseg_acc(pred, target.asnumpy(), num_classes=args.num_classes, ignore_index=0)
                metrics = accuracy_pixel_level(pred, target.asnumpy())
                eval_results += np.array([all_acc, metrics[1]])
                num_batches += 1
                pred_cls = np.argmax(clsoutput.asnumpy(), axis=1)
                # p += accuracy_score(pred_cls, category, normalize=False)
                t += image.shape[0]
            # Average over the true batch count: the original divided by the
            # maximum batch *index* (N - 1), which inflated the averages and
            # raised ZeroDivisionError with a single validation batch. Also
            # guard against an empty validation loader.
            denom = max(num_batches, 1)
            eval_results = [value / denom for value in eval_results.tolist()]
            cls_precision = p / max(t, 1)
            logger.info(
                "Eval Results : mAcc = {:.2f}, mIoU = {:.2f}, Classification Precision: {:.2f}".format(eval_results[0],
                                                                                                       eval_results[1],
                                                                                                       cls_precision))

            writer.add_scalar('Val Segmentation mAcc', eval_results[0], epoch)
            writer.add_scalar('Val Segmentation mIoU', eval_results[1], epoch)
            writer.add_scalar('Val Classification Precision', cls_precision, epoch)

            # Track best mAcc for reporting; checkpoint selection uses the
            # combined criterion mIoU + classification precision.
            val_acc = eval_results[0]
            if val_acc > best_acc:
                best_acc = val_acc
            val_iou = eval_results[1]
            cri = val_iou + cls_precision
            is_best = cri > best_cri
            best_cri = max(cri, best_cri)
            mysave_checkpoint(net, epoch, is_best, args.model_save_dir)
            print(">>>>>>>>>>>>>Best ACC:", best_acc)
-
-
if __name__ == "__main__":
    # CLI entry point: parse hyper-parameters, set up logging, run training.
    parser = argparse.ArgumentParser(description='Process some integers.')

    # dataloader
    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--num_classes', default=8, type=int)

    # loss
    parser.add_argument('--seg_loss', default="CE", type=str, choices=["CE", "MSE"])

    # optimizer
    parser.add_argument('--optimizer', default="SGD", type=str, choices=["SGD", "Adam"])
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--weight_decay', default=1e-4, type=float)

    # training
    parser.add_argument('--num_epoches', default=300, type=int)
    parser.add_argument('--train_img_dir', default="endoscope400/ade20k/images", type=str)
    parser.add_argument('--train_label_dir', default="endoscope400/ade20k/annotations", type=str)

    # evaluation
    parser.add_argument('--eval_per_epoch', default=1, type=int)

    parser.add_argument('--model_save_dir', default="runs/", type=str)

    args = parser.parse_args()

    # NOTE(review): filemode='a' is ignored by basicConfig without filename=;
    # kept for byte-compat of behavior, but it has no effect.
    logging.basicConfig(filemode='a',
                        format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                        datefmt='%H:%M:%S',
                        level=logging.INFO)
    sh = logging.StreamHandler()  # echo log records to the console
    # Create the log directory up front — FileHandler raises FileNotFoundError
    # when runs/ is missing — and avoid ':' in the timestamp so the file name
    # is also valid on Windows.
    os.makedirs('runs', exist_ok=True)
    fh = logging.FileHandler('runs/logs_{}'.format(time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())))

    logging.info("Training and Validation Record.")
    logger = logging.getLogger("FullNet for Endoscope")
    logger.addHandler(sh)
    logger.addHandler(fh)

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    main(args, logger)
|