|
# -*- coding: utf-8 -*-
import os

import numpy as np
import SimpleITK as sitk
import torch
from torch.autograd import Variable
from torch.utils import data
from tqdm import tqdm

from dataset import MyDataset
from metrics import dice_coefs, pixel_accs
from unet_2d import unet_2d
from unet_3d import unet_3d
from utils import AverageMeter
-
def save_result_comparison(epoch, input_np, input_mask, output_np,
                           out_dir='/home/yqw/seg/20201112/'):
    """Write the input volume, its label mask and the prediction as NIfTI files.

    Three files are written into ``out_dir``:
    ``<epoch>mask.nii.gz``, ``<epoch>im.nii.gz`` and ``<epoch>predict.nii.gz``.

    Args:
        epoch: Tag used as the file-name prefix (the caller in this file passes
            the validation-batch index).
        input_np: Batched input indexed as ``input_np[0]``; only the first
            sample is saved. It must broadcast to the prediction's shape, so a
            size-1 channel axis is accepted and a multi-channel input raises.
            May be a CPU torch tensor or a numpy array.
        input_mask: Batched label volume indexed as ``input_mask[0]``; only the
            first sample is saved.
        output_np: Predicted label volume; its shape defines the shape of all
            three saved volumes.
        out_dir: Directory the files are written to; created if it is missing.
    """
    output = np.asarray(output_np, dtype=np.float64)
    # Allocate float64 volumes shaped like the prediction and fill them via
    # numpy broadcasting: a leading size-1 channel axis is dropped, while a
    # mismatched shape raises instead of being silently reshaped.
    im = np.empty(output.shape)
    im[...] = np.asarray(input_np[0])
    mask = np.empty(output.shape)
    mask[...] = np.asarray(input_mask[0])

    # NOTE(review): SimpleITK's GetImageFromArray reads numpy axes in reversed
    # (z, y, x) order, so the last array axis becomes the image x-axis —
    # confirm this is the intended orientation.
    mask_img = sitk.GetImageFromArray(mask)
    im_img = sitk.GetImageFromArray(im)
    output_img = sitk.GetImageFromArray(output)

    # WriteImage fails if the target directory does not exist.
    os.makedirs(out_dir, exist_ok=True)
    prefix = os.path.join(out_dir, str(epoch))
    sitk.WriteImage(mask_img, prefix + 'mask.nii.gz')
    sitk.WriteImage(im_img, prefix + 'im.nii.gz')
    sitk.WriteImage(output_img, prefix + 'predict.nii.gz')
    print('successfully save result')
-
-
def _load_model(checkpoint, device):
    """Build a 3-D U-Net on *device*, load weights from *checkpoint*, set eval mode."""
    model = unet_3d().to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()
    return model


def _build_val_loader(train_size):
    """Split MyDataset into train/test subsets and return a loader over the test part."""
    whole_set = MyDataset()
    length = len(whole_set)
    test_size = length - train_size
    if train_size < 0 or test_size <= 0:
        raise ValueError(
            f'train_size={train_size} leaves no test samples '
            f'(dataset has {length} samples)')
    # NOTE(review): random_split is unseeded, so the held-out subset changes on
    # every run and is not guaranteed to match the split used during training;
    # confirm that evaluated samples were never seen in training.
    _, test_set = data.random_split(whole_set, [train_size, test_size])
    return data.DataLoader(test_set, batch_size=1, num_workers=0, shuffle=False)


def _evaluate(model, val_loader, device):
    """Run *model* over *val_loader*, save result volumes, return (mean dice, mean pixel acc)."""
    dice_coef = AverageMeter()
    pixel_acc = AverageMeter()
    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(val_loader), total=len(val_loader)):
            inputs = batch['X'].to(device)
            outputs = model(inputs).cpu().numpy()
            # Axis 1 holds the per-class scores; the predicted label of each
            # voxel is the argmax over it, giving an (N, h, w, d) label volume.
            n_class = outputs.shape[1]
            pred = outputs.argmax(axis=1)
            mask = batch['l'].cpu().numpy().reshape(pred.shape)
            save_result_comparison(batch_idx, batch['X'], mask, pred[0])
            batch_size = inputs.size(0)
            dice_coef.update(dice_coefs(pred, mask, n_class), batch_size)
            pixel_acc.update(pixel_accs(pred, mask), batch_size)
    return dice_coef.avg, pixel_acc.avg


def main(*, checkpoint='/home/yqw/seg/20201202_unetloss/1202_16.pth',
         train_size=230):
    """Evaluate a trained 3-D U-Net on the held-out split and generate result pictures.

    Args:
        checkpoint: Path to the saved ``state_dict`` of the model.
        train_size: Number of samples assigned to the (unused) training part
            of the split; the remaining samples are evaluated.

    Raises:
        ValueError: If *train_size* leaves no samples to evaluate.
    """
    use_gpu = torch.cuda.is_available()
    print(use_gpu)
    print(list(range(torch.cuda.device_count())))
    device = torch.device('cuda' if use_gpu else 'cpu')

    model = _load_model(checkpoint, device)
    val_loader = _build_val_loader(train_size)
    dice_avg, acc_avg = _evaluate(model, val_loader, device)
    print('dice_coef', dice_avg, 'pixel_acc', acc_avg)
    print("Done!")
-
-
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
|