|
import os

import numpy as np
import SimpleITK as sitk
from PIL import Image
-
def ious(pred, target, num_class, *, ignore_absent=False):
    """Return the mean intersection-over-union (IoU) over classes 0..num_class-1.

    For each class ``c`` the IoU is ``|P & T| / |P | T|``, where ``P`` and ``T``
    are the boolean masks ``pred == c`` and ``target == c``.

    Args:
        pred: array of predicted integer labels.
        target: array of ground-truth labels with the same shape as ``pred``.
        num_class: number of classes to evaluate (labels ``0..num_class-1``).
        ignore_absent: when False (default, the historical behaviour) a class
            that appears in neither ``pred`` nor ``target`` contributes an IoU
            of 0 to the mean, so even a perfect prediction scores below 1 if
            some class is absent. When True such classes are excluded from the
            mean; if every class is absent the result is NaN.

    Returns:
        The mean IoU as a NumPy float.
    """
    scores = np.zeros(num_class)
    # Tracks which classes have a non-empty union, i.e. actually occur.
    present = np.zeros(num_class, dtype=bool)
    for cls in range(num_class):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = pred_inds[target_inds].sum()
        union = pred_inds.sum() + target_inds.sum() - intersection
        if union > 0:
            scores[cls] = float(intersection) / float(union)
            present[cls] = True
    if ignore_absent:
        if not present.any():
            return np.float64('nan')
        return np.mean(scores[present])
    return np.mean(scores)
-
def pixel_accs(pred, target):
    """Return the fraction of positions where ``pred`` equals ``target``.

    The denominator counts the elements of ``target`` that compare equal to
    themselves, i.e. every element except NaN entries; for integer label
    maps this is simply the total number of elements.
    """
    n_valid = (target == target).sum()
    n_match = (pred == target).sum()
    return n_match / n_valid
-
def dice_coefs(pred, target, num_class):
    """Return the mean smoothed Dice coefficient over classes 0..num_class-1.

    Per class: ``(2*|P & T| + eps) / (|P| + |T| + eps)`` with ``eps = 1e-5``,
    where ``P`` / ``T`` are the positions labelled with that class in
    ``pred`` / ``target``. A class missing from both arrays therefore scores 1.0.
    """
    eps = 1e-5

    def _class_dice(label):
        # Boolean masks of where each array carries this label.
        in_pred = pred == label
        in_target = target == label
        overlap = in_pred[in_target].sum()
        return float((2 * overlap + eps) / (in_pred.sum() + in_target.sum() + eps))

    per_class = np.array([_class_dice(label) for label in range(num_class)], dtype=float)
    return np.mean(per_class)
-
def save_result_comparison(epoch, input_np, input_mask, output_np, *,
                           check_dir='/home/yqw/seg/check/',
                           predict_dir='/home/yqw/seg/check1/',
                           volume_shape=(128, 128, 16)):
    """Write the first sample's image and mask, plus a prediction, as NIfTI files.

    The first batch element of ``input_np`` and ``input_mask`` and the whole
    of ``output_np`` are copied into float64 volumes of ``volume_shape`` and
    saved with SimpleITK:

    * ``<check_dir>/mask.nii.gz`` and ``<check_dir>/im.nii.gz`` (overwritten
      on every call, since their names carry no epoch);
    * ``<predict_dir>/<epoch>predict.nii.gz``.

    Args:
        epoch: value embedded in the prediction file name via ``str(epoch)``.
        input_np: batch indexed as ``[0, :, :, :, :]``; that slice must
            broadcast into ``volume_shape`` (e.g. a single leading channel of
            size 1 -- TODO(review): confirm the channel layout with callers).
        input_mask: batch whose ``[0, :, :, :]`` slice broadcasts into
            ``volume_shape``.
        output_np: 3-D prediction broadcastable into ``volume_shape``.
        check_dir: directory for the image/mask files; created if missing.
        predict_dir: directory for per-epoch predictions; created if missing.
        volume_shape: shape of the saved volumes (default keeps the original
            hard-coded 128x128x16).

    Raises:
        ValueError: if a slice cannot be broadcast into ``volume_shape``.
    """
    # Copying into float64 buffers fixes the saved pixel type and validates
    # each input's shape via broadcasting.
    im = np.zeros(volume_shape)
    mask = np.zeros(volume_shape)
    output = np.zeros(volume_shape)
    im[...] = input_np[0, :, :, :, :]
    mask[...] = input_mask[0, :, :, :]
    output[...] = output_np[:, :, :]

    # SimpleITK reads NumPy axes in reverse order (z, y, x), so the NIfTI
    # axes appear reversed relative to the array axes.
    mask_img = sitk.GetImageFromArray(mask)
    im_img = sitk.GetImageFromArray(im)
    output_img = sitk.GetImageFromArray(output)

    # WriteImage fails on a missing directory; create the targets up front.
    os.makedirs(check_dir, exist_ok=True)
    os.makedirs(predict_dir, exist_ok=True)
    sitk.WriteImage(mask_img, os.path.join(check_dir, 'mask.nii.gz'))
    sitk.WriteImage(im_img, os.path.join(check_dir, 'im.nii.gz'))
    sitk.WriteImage(output_img, os.path.join(predict_dir, str(epoch) + 'predict.nii.gz'))
    print('successfully save result')
|