|
#coding:utf8
"""Evaluate a trained 3D U-Net checkpoint on the NIfTI test volumes.

For every file in ``TEST_DIR`` the script runs overlapping sliding-window
inference, writes the predicted label map to ``OUTPUT_DIR`` and prints the
Dice score of each foreground label.  The mean Dice per label over all test
files is printed at the end.
"""
import models
from config import *  # provides `opt`; kept first so the explicit imports below win

import itertools
import os
import time

import nibabel as nib
import numpy
import numpy as np
import torch
import torch as t
import torch.nn as nn
from tqdm import tqdm

check_ing_path = './check/'

# Checkpoint to evaluate.  Set to None to evaluate every file found in
# `check_ing_path` instead of this single pinned checkpoint.
CHECKPOINT_NAME = '0.996638556321462_4444_220.29557291666666_0.0001_4_0826_09:03:35.pth'

TEST_DIR = './testdata/nii/'
OUTPUT_DIR = './pre_out/'

NUM_CLASSES = 5            # number of channels accumulated from the network output
WINDOW = (256, 256, 96)    # sliding-window size along the three spatial axes
STRIDE = (128, 128, 48)    # step between consecutive windows (half a window)

# (name used in the printed report, label value in the segmentation)
DICE_LABELS = (('WT', 1), ('WW', 2), ('TT', 3), ('EE', 4))


def _dice(pred, truth, label):
    """Return the Dice coefficient of `label`, or None if it occurs in neither array.

    The overlap is counted with an explicit mask intersection.  The previous
    ``pred * truth == label ** 2`` test also counted mismatches whose product
    happens to equal the square (for label 2: the pairs (1, 4) and (4, 1)).
    """
    pred_mask = pred == label
    truth_mask = truth == label
    denominator = pred_mask.sum() + truth_mask.sum()
    if denominator == 0:
        return None
    return 2.0 * np.logical_and(pred_mask, truth_mask).sum() / denominator


def _window_starts(length, window, stride):
    """Return the start offsets of every window that fits fully inside `length`.

    A window ending exactly on the last voxel is included; the range is empty
    when the axis is shorter than the window.
    """
    return range(0, length - window + 1, stride)


def _load_model(checkname):
    """Build the `unet_3d` network and load the weights stored in `checkname`."""
    model = getattr(models, 'unet_3d')()
    # Load on the CPU first so a checkpoint saved on a GPU also works on a
    # CPU-only machine; the model is moved to the GPU afterwards if requested.
    state = t.load(os.path.join(check_ing_path, checkname), map_location='cpu')
    model.load_state_dict(state)
    if opt.use_gpu:
        model.cuda()
    model.eval()
    return model


def _predict_labels(model, image):
    """Run sliding-window inference on `image` and return the arg-max label map.

    `image` is indexed as (channel, x, y, z).  Softmax outputs of overlapping
    windows are averaged before the arg-max; voxels not covered by any window
    keep all-zero scores and therefore get label 0.
    NOTE(review): assumes the model returns (1, NUM_CLASSES, *WINDOW) for a
    (1, 1, *WINDOW) input -- confirm against `models.unet_3d`.
    """
    spatial = image.shape[1:]
    prob_sum = np.zeros((NUM_CLASSES,) + spatial)
    visits = np.zeros(spatial)
    softmax = nn.Softmax(dim=1)
    axis_starts = [
        _window_starts(length, window, stride)
        for length, window, stride in zip(spatial, WINDOW, STRIDE)
    ]

    with torch.no_grad():
        for x0, y0, z0 in itertools.product(*axis_starts):
            region = (
                slice(x0, x0 + WINDOW[0]),
                slice(y0, y0 + WINDOW[1]),
                slice(z0, z0 + WINDOW[2]),
            )
            patch = np.ascontiguousarray(image[(slice(None),) + region])
            batch = torch.from_numpy(patch).unsqueeze(0).float()
            if opt.use_gpu:
                batch = batch.cuda()
            probs = softmax(model(batch)).squeeze(0).cpu().numpy()
            prob_sum[(slice(None),) + region] += probs
            visits[region] += 1

    visits[visits == 0] = 1  # avoid 0/0 on voxels no window reached
    return np.argmax(prob_sum / visits, axis=0)


def _mean_or_nan(values):
    """Return the mean of `values`, or NaN (without a NumPy warning) when empty."""
    return np.mean(values) if values else float('nan')


def _evaluate(model):
    """Predict every volume in TEST_DIR, save the label maps and print Dice scores."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    scores = {name: [] for name, _ in DICE_LABELS}

    for fodername in sorted(os.listdir(TEST_DIR)):
        print(fodername)
        volume = nib.load(os.path.join(TEST_DIR, fodername))
        # `np.asanyarray(img.dataobj)` is nibabel's documented replacement for
        # the removed `img.get_data()`.
        data = np.asanyarray(volume.dataobj).astype(float)
        image = data[0:1]   # channel 0 is fed to the network
        truth = data[1]     # channel 1 is compared against the prediction

        pred = _predict_labels(model, image)

        out_image = nib.Nifti1Image(pred.astype(np.float64), np.eye(4))
        out_image.set_data_dtype(np.float64)
        nib.save(out_image, os.path.join(OUTPUT_DIR, fodername))

        case_scores = []
        for name, label in DICE_LABELS:
            dice = _dice(pred, truth, label)
            if dice is None:
                # Label absent from both prediction and truth: excluded from
                # the mean and shown as NaN instead of a stale/undefined value.
                case_scores.append(float('nan'))
            else:
                scores[name].append(dice)
                case_scores.append(dice)
        print(*case_scores)

    summary = ['mean ']
    for name, _ in DICE_LABELS:
        summary += [name + ':', _mean_or_nan(scores[name])]
    print(*summary)


def main():
    """Evaluate the pinned checkpoint, or every checkpoint if none is pinned."""
    if CHECKPOINT_NAME is None:
        checkpoint_names = sorted(os.listdir(check_ing_path))
    else:
        checkpoint_names = [CHECKPOINT_NAME]

    for index, checkname in enumerate(checkpoint_names):
        print(index, checkname)
        _evaluate(_load_model(checkname))

    print('over!')


if __name__ == '__main__':
    main()
|