|
- import argparse
- import sys
- import numpy as np
- import torch
- import cv2
- import math
- from skimage.segmentation import chan_vese
- import os
- from torchvision import transforms
- from utils.read_data import Covid_19_Dataset
- from torch.utils.data import DataLoader
- from networks.unet import UNet
- from collections import OrderedDict
- from utils.util import weight_to_cpu
- from matplotlib import pyplot as plt
- import numpy as np
- from PIL import Image
- from sklearn.metrics import precision_recall_curve, auc
- from tqdm import tqdm
- from networks.vision_transformer import SwinUnet
- from config import get_config
- from networks.swin_unetr import SwinUNETR
-
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-
def parse_args():
    """Build and parse command-line options for COVID-19 lesion segmentation.

    Returns an ``argparse.Namespace`` with data paths, model paths, and the
    Swin-Transformer config options expected by ``get_config``.
    """
    p = argparse.ArgumentParser(description='Segment Covid-19 lesion area')

    # --- data / output locations -------------------------------------------
    p.add_argument('--test_set_dir', '-td', type=str,
                   default='./data_1/gan/eval/covid-cell',
                   help='folder to load test set')
    p.add_argument('--gt_dir', '-gtd', type=str,
                   default='./data_1/gan/eval/covid-cell/gt',
                   help='folder of ground truth')
    p.add_argument('--save_dir', '-sd', type=str,
                   default='./test_result/20230730_1/epoch=65/covid-cell/threshold=20',
                   help='folder to save result')

    # --- model / evaluation settings ---------------------------------------
    p.add_argument('--batch_size', '-bs', type=int, default=4,
                   help='batch size per gpu')
    p.add_argument('--power', '-k', type=int, default=2,
                   help='power of weight')
    p.add_argument('--input_channel', type=int, default=1,
                   help='input channel of model')
    p.add_argument('--pretrain_swin_unet_path', type=str,
                   default='./check_points/20230405_1(epoch=65)/g.pkl',
                   help='pretrained swin-unet')

    # --- options consumed by get_config / the swin backbone ----------------
    p.add_argument("--opts", default=None, nargs='+',
                   help="Modify config options by adding 'KEY VALUE' pairs. ")
    p.add_argument('--zip', action='store_true',
                   help='use zipped dataset instead of folder dataset')
    p.add_argument('--cache-mode', type=str, default='part',
                   choices=['no', 'full', 'part'],
                   help='no: no cache, '
                        'full: cache all data, '
                        'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
    p.add_argument('--resume', help='resume from checkpoint')
    p.add_argument('--accumulation-steps', type=int,
                   help="gradient accumulation steps")
    p.add_argument('--use-checkpoint', action='store_true',
                   help="whether to use gradient checkpointing to save memory")
    p.add_argument('--amp-opt-level', type=str, default='O1',
                   choices=['O0', 'O1', 'O2'],
                   help='mixed precision opt level, if O0, no amp is used')
    p.add_argument('--tag', help='tag of experiment')
    p.add_argument('--eval', action='store_true',
                   help='Perform evaluation only')
    p.add_argument('--throughput', action='store_true',
                   help='Test throughput only')
    p.add_argument('--img_size', type=int, default=224,
                   help='input patch size of network input')
    p.add_argument('--cfg', type=str,
                   default='./configs/swin_tiny_patch4_window7_224_lite.yaml',
                   metavar="FILE", help='path to config file', )

    return p.parse_args()
-
def mat_math(intput, str):
    """Apply an element-wise ``atan`` or ``sqrt`` to a 2-D array.

    Vectorized replacement for the original per-pixel double Python loop.
    Unlike the original (which wrote results back into *intput* and
    returned the same object), this returns a new array and leaves the
    argument untouched; every call site in this file passes a temporary
    expression, so nothing observes the difference.

    The parameter names (misspelled "input", shadowed builtin ``str``)
    are kept for interface compatibility with existing callers.

    :param intput: 2-D numeric array.
    :param str: operation name, ``"atan"`` or ``"sqrt"``.
    :return: transformed array; any other operation name returns the
             input unchanged (matching the original's fall-through).
    """
    if str == "atan":
        return np.arctan(intput)
    if str == "sqrt":
        return np.sqrt(intput)
    return intput
-
# Chan-Vese (CV) update function
def CV(LSF, img, mu, nu, epison, step):
    """Perform one gradient-descent update of a Chan-Vese level set.

    :param LSF: current level-set function (2-D array).
    :param img: image being segmented.
    :param mu: weight of the distance-regularisation penalty term.
    :param nu: weight of the contour-length term.
    :param epison: smoothing width of the Dirac/Heaviside approximations.
    :param step: gradient-descent step size.
    :return: the updated level-set function (new array).
    """
    eps = 0.000001  # guards against division by a zero gradient magnitude

    # Smoothed Dirac delta and Heaviside of the level set.
    delta = (epison / math.pi) / (epison * epison + LSF * LSF)
    heaviside = 0.5 * (1 + (2 / math.pi) * mat_math(LSF / epison, "atan"))

    # Unit normal of the level-set gradient; its divergence is the curvature.
    grad_y, grad_x = np.gradient(LSF)
    grad_mag = mat_math(grad_x * grad_x + grad_y * grad_y, "sqrt")
    nx = grad_x / (grad_mag + eps)
    ny = grad_y / (grad_mag + eps)
    _, nxx = np.gradient(nx)
    nyy, _ = np.gradient(ny)
    curvature = nxx + nyy

    # Contour-length term.
    length_term = nu * delta * curvature

    # Penalty keeping the level set close to a signed distance function.
    penalty_term = mu * (cv2.Laplacian(LSF, -1) - curvature)

    # Region means inside (C1) and outside (C2) the current contour.
    inside = heaviside * img
    outside = (1 - heaviside) * img
    C1 = inside.sum() / heaviside.sum()
    C2 = outside.sum() / (1 - heaviside).sum()
    fitting_term = delta * (-1 * (img - C1) * (img - C1) + 1 * (img - C2) * (img - C2))

    return LSF + step * (length_term + penalty_term + fitting_term)
-
class Predict:
    """Evaluate a pretrained Swin-Unet on the COVID-19 test set.

    For each image the pipeline computes the absolute difference between
    the input and the network reconstruction, thresholds it into a lesion
    mask, saves a 6-panel summary figure, and accumulates per-pixel
    scores/labels for a precision-recall curve.
    """

    def __init__(self, args):
        # Normalisation constants, used both for the input transform and
        # when restoring network output back to [0, 255] images.
        self.mean, self.std = 0.5, 0.5
        self.args = args
        self.dataloader = self.get_dataloader()
        self.auto_encoder = self.get_swin_unet(args)
        # Flattened per-pixel prediction scores and ground-truth labels,
        # accumulated across the whole test set for the PR curve.
        self.prediction = []
        self.label = []
        # Difference-map value above which a pixel is called "lesion".
        self.segmentation_threshold = 20

    def get_dataloader(self):
        """Build the evaluation DataLoader over the test set."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)
        ])
        dataset = Covid_19_Dataset(data_dir=self.args.test_set_dir,
                                   transform=transform)
        # Evaluation must visit every sample exactly once and be
        # reproducible, so do not shuffle and do not drop the last
        # (possibly short) batch.  The previous shuffle=True /
        # drop_last=True silently skipped up to batch_size-1 test images
        # and made the reported average dice vary between runs.
        return DataLoader(dataset,
                          batch_size=self.args.batch_size,
                          shuffle=False,
                          num_workers=0,
                          drop_last=False,
                          pin_memory=False)

    def get_swin_unet(self, args):
        """Construct the Swin-Unet generator and load pretrained weights."""
        config = get_config(args)
        swin_unet = SwinUnet(config, img_size=224, num_classes=1).cuda()
        print('loading the parameters of generator......')
        swin_unet.load_state_dict(weight_to_cpu(self.args.pretrain_swin_unet_path))
        print('success loading!')
        return swin_unet

    def save_single_image(self, saved_dir, names, inputs):
        """Segment one batch and save a summary figure per image.

        :param saved_dir: output folder (created if missing).
        :param names: image file names for this batch.
        :param inputs: normalised input batch tensor.
        :return: summed Dice score over the batch.
        """
        os.makedirs(saved_dir, exist_ok=True)
        with torch.no_grad():
            outputs, gdps = self.auto_encoder(inputs)
        print(outputs.shape, gdps.shape)
        total_dice = 0
        for idx in range(inputs.shape[0]):  # renamed from `input`: no builtin shadowing
            name = names[idx]
            left = self.restore(inputs[idx])    # source image
            right = self.restore(outputs[idx])  # reconstruction
            lesion = self.restore(gdps[idx])    # auxiliary lesion output

            shotname, extension = os.path.splitext(name)

            # gt.resize must use nearest-neighbour interpolation so the
            # mask stays binary after resizing.
            gt = Image.open(os.path.join(self.args.gt_dir,
                                         shotname + '_mask' + extension)).resize((224, 224), Image.NEAREST)
            gt = np.array(gt)
            gt = np.where(gt == 0, 0, 255)

            # Absolute reconstruction error: large values suggest lesions.
            diff = np.where(left > right, left - right, right - left).clip(0, 255).astype(np.uint8)

            # ---- Chan-Vese contour evolution (visualisation only) ----
            # Initial level set: -1 inside the 222x222 interior, +1 on the
            # one-pixel border (assumes 224x224 inputs).
            IniLSF = np.ones((diff.shape[0], diff.shape[1]), diff.dtype)
            IniLSF[1:223, 1:223] = -1
            IniLSF = -IniLSF

            mu = 1
            nu = 0.003 * 255 * 255
            num = 100
            epison = 1
            step = 0.1
            LSF = IniLSF
            # Loop variable renamed from `i`, which shadowed the batch
            # index; the dead `if i % 1 == 0` guard (always true) is gone.
            for _ in range(1, num):
                LSF = CV(LSF, diff, mu, nu, epison, step)
                # Display the evolving contour on top of the diff map.
                plt.imshow(diff), plt.xticks([]), plt.yticks([])
                plt.contour(LSF, [0], colors='r', linewidths=2)
                plt.draw(), plt.show(block=False), plt.pause(0.01)

            # Binary segmentation by thresholding the difference map.
            # (This assignment was commented out upstream, which made the
            # uses of diff_1 below raise NameError at runtime.)
            diff_1 = np.where(diff >= self.segmentation_threshold, 255, 0).astype(np.uint8)

            self.save_prediction_label(diff, gt)
            dice_score = self.dice_coeff(diff_1 / 255, gt / 255)
            total_dice += dice_score

            self._save_figure(saved_dir, name, left, right, lesion, gt, diff_1, dice_score)
        return total_dice

    def _save_figure(self, saved_dir, name, left, right, lesion, gt, diff_1, dice_score):
        """Save the 2x3 summary figure for a single image."""
        plt.figure(num='swin_unet result', figsize=(8, 8))

        # Top row: input, reconstruction, auxiliary lesion output.
        for pos, img, title in ((1, left, 'source image'),
                                (2, right, 'swin_unet output1'),
                                (3, lesion, 'swin_unet output2')):
            plt.subplot(2, 3, pos)
            plt.title(title)
            plt.imshow(img, cmap='gray')
            plt.axis('off')

        plt.subplot(2, 3, 4)
        plt.imshow(lesion, cmap='jet')
        plt.colorbar(orientation='horizontal')
        plt.title('lesion in heatmap')
        plt.axis('off')

        plt.subplot(2, 3, 5)
        plt.imshow(gt, cmap='gray')
        plt.title('GT')
        plt.axis('off')

        plt.subplot(2, 3, 6)
        plt.imshow(diff_1, cmap='gray')
        plt.title('segmentation result ' + str(round(dice_score, 3)))
        plt.axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(saved_dir, name))
        plt.close()

    def restore(self, x):
        """Undo normalisation and convert a tensor to a uint8 numpy image."""
        x = torch.squeeze(x).data.cpu()
        # Invert transforms.Normalize(mean, std): x = x * std + mean.
        x.mul_(self.std).add_(self.mean)
        x = x.numpy()
        return np.clip(x * 255, 0, 255).astype(np.uint8)

    def dice_score(self, X, Y):
        """Dice score of two same-shaped binary arrays."""
        assert X.shape == Y.shape
        return self.dice_coeff(X, Y)

    def dice_coeff(self, pred, target):
        """Smoothed Dice coefficient: (2|A∩B|+1) / (|A|+|B|+1)."""
        smooth = 1.
        assert pred.shape == target.shape
        intersection = (pred * target).sum()
        return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

    def save_prediction_label(self, diff, gt):
        """Accumulate flattened per-pixel scores and labels (scaled to [0, 1])."""
        predict = (diff / 255.).reshape(-1)
        label = (gt / 255.).reshape(-1)
        self.prediction.extend(list(predict))
        self.label.extend(list(label))

    def show_PR_curve(self):
        """Plot and save the precision-recall curve over all accumulated pixels."""
        print(len(self.label), len(self.prediction))
        precision, recall, thresholds = precision_recall_curve(self.label, self.prediction)
        pr_auc = auc(recall, precision)
        print('auc:', pr_auc)
        plt.plot(recall, precision, color='blue')
        plt.title('Precision/Recall Curve')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.legend(loc="upper right", labels=['Ours(AUC={})'.format(round(pr_auc, 3))])
        plt.savefig(os.path.join(self.args.save_dir, 'PR_curve.png'))
        plt.show()

        # Persist raw curve data for later re-plotting/comparison.
        np.savez(os.path.join(self.args.save_dir, 'PR_data.npz'), precision=precision, recall=recall)

    def predictor(self):
        """Run the whole test set through the model and report average Dice."""
        self.auto_encoder.eval()

        save_dir = self.args.save_dir
        print(save_dir)
        total_num = 0
        total_dice = 0

        for idx, item in tqdm(enumerate(self.dataloader), desc='processing result...', total=len(self.dataloader),
                              ncols=100):
            inputs, names = item['image'], item['name']
            total_num += len(names)
            batch_dice = self.save_single_image(save_dir, names, inputs)
            total_dice += batch_dice
        print('average dice:', total_dice / total_num)
-
-
if __name__ == '__main__':
    # Parse CLI options, run the full evaluation, then plot the PR curve.
    cli_args = parse_args()
    runner = Predict(cli_args)
    runner.predictor()
    runner.show_PR_curve()
-
-
-
|