|
# coding: utf8
# Project-local imports stay first on purpose: `from sets import *` can bring
# in any public name, and the explicit imports below must keep taking
# precedence over it, exactly as in the original import order.
import models
from sets import *

import itertools
import os
import time

import nibabel as nib
import numpy
import numpy as np
import torch
import torch as t
from tqdm import tqdm
-
# Evaluate every checkpoint in ./check/ on every NIfTI volume in
# ./testdata/nii/ using sliding-window inference, write the predicted label
# maps to ./pre_out/, and report per-volume, mean and std Dice for label 1
# (printed under the 'WT' tag, as in the original script).

CHECKPOINT_DIR = './check/'
TEST_DIR = './testdata/nii/'
OUTPUT_DIR = './pre_out/'

# Sliding-window geometry from the original script: windows of 128 x 128 x 16
# voxels, strides 64 / 64 / 8 along array axes 1-3, and no border margin.
PATCH_SIZE = (128, 128, 16)
PATCH_STRIDE = (64, 64, 8)
BORDER = 0


def window_starts(length, size, stride, border=0):
    """Return start offsets of sliding windows along one axis.

    Windows of ``size`` voxels are placed every ``stride`` voxels inside
    ``[border, length - border)``. If the regular grid does not reach the far
    edge, one extra window aligned to that edge is appended, so every voxel
    inside the margin is covered at least once. Returns an empty list when
    the axis is shorter than one window.

    >>> window_starts(240, 128, 64)
    [0, 64, 112]
    """
    last = length - border - size
    if last < border:
        return []
    starts = list(range(border, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts


def sliding_window_predict(model, volume, device, num_classes=2):
    """Sum softmax class probabilities of ``model`` over overlapping 3-D patches.

    Args:
        model: callable mapping a (1, C, X, Y, Z) float tensor to logits with
            the class axis at dim 1. The (1, num_classes, X, Y, Z) output
            layout is assumed from how the original script used it; confirm
            against ``models.unet_3d``.
        volume: array of shape (C, X, Y, Z) holding the input channel(s).
        device: torch device the patches are moved to (the model's device).
        num_classes: number of output channels to accumulate.

    Returns:
        float64 array of shape (num_classes, X, Y, Z) holding the summed
        per-window probabilities. Voxels that no window covers stay zero;
        this only happens when an axis is shorter than one patch.
    """
    spatial = volume.shape[1:]
    prob = np.zeros((num_classes,) + tuple(spatial))
    per_axis = [
        window_starts(dim, size, stride, BORDER)
        for dim, size, stride in zip(spatial, PATCH_SIZE, PATCH_STRIDE)
    ]
    with torch.no_grad():
        for corner in itertools.product(*per_axis):
            window = tuple(slice(c, c + s) for c, s in zip(corner, PATCH_SIZE))
            patch = torch.from_numpy(volume[(slice(None),) + window])
            patch = patch.unsqueeze(0).float().to(device)
            # [0] drops the batch axis explicitly; squeeze() could also drop
            # a spatial axis that happens to have length 1.
            score = torch.softmax(model(patch), dim=1)[0].cpu().numpy()
            prob[(slice(None),) + window] += score
    return prob


def dice_score(pred, truth, label=1):
    """Return the Dice coefficient of the ``label`` region of two label maps.

    Returns None when ``label`` is absent from both maps, where Dice is
    undefined.
    """
    pred_mask = pred == label
    true_mask = truth == label
    total = np.count_nonzero(pred_mask) + np.count_nonzero(true_mask)
    if total == 0:
        return None
    return 2.0 * np.count_nonzero(pred_mask & true_mask) / total


def main():
    """Run the evaluation over all checkpoints and test volumes."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    for ckpt_index, checkname in enumerate(sorted(os.listdir(CHECKPOINT_DIR))):
        print(ckpt_index, checkname)

        # `opt` is not defined in this file; presumably it is exported by
        # `from sets import *` (project config) -- TODO confirm.
        device = torch.device('cuda' if opt.use_gpu else 'cpu')
        model = models.unet_3d()
        # map_location lets GPU-saved checkpoints load on CPU-only hosts;
        # the model is moved to the target device afterwards.
        state = torch.load(os.path.join(CHECKPOINT_DIR, checkname),
                           map_location='cpu')
        model.load_state_dict(state)
        model.to(device)
        model.eval()

        WT_dice = []
        for vol_index, fodername in enumerate(sorted(os.listdir(TEST_DIR))):
            print(vol_index, fodername)
            data = nib.load(os.path.join(TEST_DIR, fodername)).get_fdata()
            print(data.shape)
            # Assumes axis 0 holds channels: 0 = image, 1 = ground-truth
            # label map (as indexed by the original script) -- confirm.
            vector = data[0:1]
            tru = data[1]

            prob = sliding_window_predict(model, vector, device)
            pre = np.argmax(prob, axis=0)

            # Stored as float64, as before. Converting up front avoids
            # nibabel (>=5) rejecting int64 image data.
            new_image = nib.Nifti1Image(pre.astype(np.float64), np.eye(4))
            new_image.set_data_dtype(np.float64)
            nib.save(new_image, os.path.join(OUTPUT_DIR, fodername))

            WT_Dice = dice_score(pre, tru, label=1)
            if WT_Dice is not None:
                WT_dice.append(WT_Dice)
            print(WT_Dice)

        mean_WT_dice = np.mean(WT_dice)
        print('mean ', 'WT:', mean_WT_dice)

        std_WT_dice = np.std(WT_dice)
        print('std ', 'WT:', std_WT_dice)


if __name__ == '__main__':
    main()
    # The original script busy-looped forever after evaluating, presumably to
    # keep the job (and its resources) alive. Keep that behaviour, but sleep
    # instead of pinning a CPU core at 100%.
    while True:
        time.sleep(3600)
|