|
- import glob
- import os
- import time
- import torch
- import sys
- import argparse
- import numpy as np
- from numpy import *
- import nibabel as nib
- import SimpleITK as sitk
- import scipy.ndimage as ndimage
-
- from monai.transforms import (
- Invertd,
- EnsureChannelFirstd,
- Compose,
- LoadImaged,
- Orientationd,
- ScaleIntensityRanged,
- Spacingd,
- EnsureTyped,
- )
- from monai.inferers import sliding_window_inference
-
# Make the local ./network package importable before loading the model class.
sys.path.insert(0, './network')
from unet_flare import BasicUNet

# Let cuDNN autotune convolution algorithms; pays off because the
# sliding-window inference feeds fixed-size patches.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
-
def load_model(args, device, modelo_path, modelc_path):
    """Build the 2D BasicUNet, load organ weights, and optionally merge
    cancer-branch weights from a second checkpoint.

    Args:
        args: parsed CLI namespace; uses ``args.model`` (architecture variant)
            and ``args.num_classes`` (output channels).
        device: torch device string/object the model is moved to.
        modelo_path: path to the organ-segmentation checkpoint (required).
        modelc_path: path to the cancer-segmentation checkpoint; '' skips
            the merge step.

    Returns:
        The model wrapped in ``torch.nn.DataParallel``, in eval mode.
    """
    # The two variants differ only in channel widths.
    if args.model == 'unet_c16':
        features = (16, 32, 64, 128, 256, 128)
    else:
        features = (32, 64, 128, 256, 512, 256)

    model = BasicUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=args.num_classes,
        features=features,
        norm="batch",  # Norm.BATCH
    )
    model = torch.nn.DataParallel(model).to(device)

    # map_location keeps loading robust when the checkpoint was saved on GPU
    # but inference runs on CPU.
    checkpoint_o = torch.load(modelo_path, map_location=device)
    model.load_state_dict(checkpoint_o['model_state_dict'])

    if modelc_path != '':
        # BUG FIX: the original loop did `value = checkpoint_c[...]`, which
        # only rebinds the local name and never writes into the model.
        # Copy the cancer-branch tensors in place instead.
        checkpoint_c = torch.load(modelc_path, map_location=device)
        cancer_state = checkpoint_c['model_state_dict']
        with torch.no_grad():
            for key, param in model.named_parameters():
                # Keys look like "module.<submodule>...."; index 1 is the
                # submodule name because of the DataParallel wrapper.
                branch = key.split('.')[1]
                if branch in ['upcat_2_1', 'upcat_1_1', 'final_conv_1']:  # weights for cancer
                    param.copy_(cancer_state[key])

    model.eval()
    return model
-
-
def inference(args, test_dir, model, output_dir):
    """Run slice-by-slice 2D inference over every volume in ``test_dir``
    and write uint8 label maps to ``output_dir``.

    Args:
        args: parsed CLI namespace; uses ``a_min``/``a_max`` (intensity
            window), ``image_size`` (sliding-window ROI) and ``sw_batch_size``.
        test_dir: directory containing the input NIfTI volumes.
        model: the network returned by :func:`load_model` (already in eval mode).
        output_dir: directory the predictions are written into.

    NOTE(review): relies on the module-global ``device`` set in __main__.
    """
    test_transform = Compose([
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Orientationd(keys=["image"], axcodes='RAS'),
        ScaleIntensityRanged(
            keys=["image"], a_min=args.a_min, a_max=args.a_max,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        Spacingd(keys=["image"], pixdim=(1, 1, 2.5), mode="bilinear"),
        EnsureTyped(keys=["image"]),
    ])
    # Maps predictions back to the original geometry; nearest interpolation
    # keeps the integer labels intact.
    invert_transform = Invertd(
        keys="pred",
        transform=test_transform,
        orig_keys="image",
        nearest_interp=True,
    )

    test_images = sorted(glob.glob(f'{test_dir}/*'))
    test_files = [{"image": image_name} for image_name in test_images]
    print(len(test_files))

    with torch.no_grad():
        for test_data in test_files:
            start = time.time()
            test_data = test_transform(test_data)
            test_inputs = test_data["image"].to(device)
            original_affine = test_data["image_meta_dict"]["affine"].numpy()
            img_name = test_data["image_meta_dict"]["filename_or_obj"].split("/")[-1]
            save_name = img_name.replace('_0000', '')
            print("Inference on case {}".format(img_name))

            # BUG FIX: the original aliased the input (`test_output = test_inputs`)
            # and wrote labels into the tensor still being read — correct only
            # by accidental slice ordering. Use a separate output buffer.
            test_output = torch.zeros_like(test_inputs)
            for s in range(test_inputs.shape[-1]):
                # One axial slice at a time, with a leading batch dim.
                s_input = test_inputs[..., s].unsqueeze(dim=0)
                s_outputs = sliding_window_inference(
                    s_input, args.image_size, args.sw_batch_size, model, overlap=0.5
                )
                test_output[..., s] = pred_label(s_outputs)

            test_data['pred'] = test_output
            pred_tensor = invert_transform(test_data)
            prd_arr = pred_tensor['pred'].squeeze().cpu().numpy()

            nib.save(
                nib.Nifti1Image(prd_arr.astype(np.uint8), original_affine),
                os.path.join(output_dir, save_name),
            )

            end = time.time()
            print(f'Inference time is {end-start}s for this case.')
    torch.cuda.empty_cache()
    return
-
-
def save_result(file_id, files, test_output, save_dir):
    """Resample a predicted label volume to the raw image geometry and save
    it next to its case name under ``save_dir``.

    Args:
        file_id: index into ``files``.
        files: list of dicts with an 'image' path entry (as built in inference).
        test_output: predicted label array in RAS/transformed orientation.
        save_dir: output directory (must exist).

    Best-effort: failures are reported but do not propagate, matching the
    original intent — but no longer silently swallowed by a bare ``except``.
    """
    try:
        raw_img_p = files[file_id]['image']
        save_name = raw_img_p.split('/')[-1].replace('_0000', '')
        raw_img = sitk.ReadImage(raw_img_p)
        raw_img_arr = sitk.GetArrayFromImage(raw_img)

        # Undo the axis order / flips introduced by the RAS transform chain.
        prd_arr1 = test_output.transpose(2, 1, 0)
        prd_arr2 = np.flip(prd_arr1, axis=(1, 2))
        # BUG FIX: pass the target SHAPE, not the array itself —
        # resample_3d unpacks three scalar extents.
        prd_arr = resample_3d(prd_arr2, raw_img_arr.shape)

        out = sitk.GetImageFromArray(prd_arr.astype(np.uint8))
        out.SetDirection(raw_img.GetDirection())
        out.SetOrigin(raw_img.GetOrigin())
        out.SetSpacing(raw_img.GetSpacing())
        save_path = f'{save_dir}/{save_name}'
        sitk.WriteImage(out, save_path)
        print(save_path)
    except (KeyError, IndexError, ValueError, RuntimeError, OSError) as exc:
        print(f'save_result failed for case {file_id}: {exc}')
        return
-
-
def pred_label(pred):
    """Fuse the two-head network output into a single label map.

    ``pred`` holds (cancer_logits, organ_logits), each shaped (B, C, H, W).
    Voxels the cancer head classifies as positive get label 14; every other
    voxel takes the organ head's argmax class.
    """
    cancer_cls = torch.softmax(pred[0], 1).argmax(dim=1)
    organ_cls = torch.softmax(pred[1], 1).argmax(dim=1)
    return torch.where(cancer_cls == 1, 14, organ_cls)
-
def resample_3d(img, target_size):
    """Nearest-neighbour resample of a 3D label volume to a target shape.

    Args:
        img: 3D ndarray of integer labels.
        target_size: target extents as a 3-tuple ``(tx, ty, tz)``, or a
            reference 3D ndarray whose ``.shape`` is the target (callers in
            this file pass the raw image array directly — accept both).

    Returns:
        ``img`` zoomed to the target shape; ``order=0`` keeps label values
        exact (no interpolation between classes).
    """
    if isinstance(target_size, np.ndarray):
        target_size = target_size.shape
    imx, imy, imz = img.shape
    tx, ty, tz = target_size
    zoom_ratio = (float(tx) / float(imx), float(ty) / float(imy), float(tz) / float(imz))
    img_resampled = ndimage.zoom(img, zoom_ratio, order=0, prefilter=False)
    return img_resampled
-
-
-
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--a_min", default=-200.0, type=float, help="a_min in ScaleIntensityRanged")
    parser.add_argument("--a_max", default=300.0, type=float, help="a_max in ScaleIntensityRanged")
    parser.add_argument("--input_size", default=224, type=int, help="image size for network input")
    parser.add_argument("--num_classes", default=15, type=int, help="number of segmentation classes, including background")
    parser.add_argument("--sw_batch_size", default=4, type=int)
    parser.add_argument("--model", default='unet_c16', type=str)
    parser.add_argument("--modelo_path", default='', type=str, help="path to the organ-model checkpoint")
    parser.add_argument("--modelc_path", default='', type=str, help="path to the cancer-model checkpoint; empty skips the merge")
    parser.add_argument("--in_file", default='', help="directory of input volumes")
    parser.add_argument("--out_file", default='', help="directory the predictions are written to")
    args = parser.parse_args()

    # Square ROI for sliding_window_inference.
    args.image_size = (args.input_size, args.input_size)

    # Load the network, assigning it to the selected device.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = load_model(args, device, args.modelo_path, args.modelc_path)

    # Input directory.
    test_dir = args.in_file

    # Output directory — create it up front so nib.save does not fail later.
    output_dir = args.out_file
    os.makedirs(output_dir, exist_ok=True)

    # Inference.
    inference(args, test_dir, model, output_dir)
|