|
- #!/usr/bin/env python
- # -*- coding: UTF-8 -*-
- import h5py
- import scipy.io as io
- import PIL.Image as Image
- from mindspore.train.serialization import load_param_into_net, load_checkpoint
- import mindspore.dataset.vision.py_transforms as py_vision
- from model import UCSRNet
- from mindspore import Tensor
- import numpy as np
- import matplotlib.pyplot as plt
- import cv2
- import shutil
- from skimage.feature import peak_local_max
- import time
- import sys
- import os
- import numpy as np
- from PIL import ImageFile
- import argparse
-
# Allow PIL to load partially written / truncated images instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Make project-local model code importable when running this file directly.
sys.path.append(os.path.join(os.getcwd(), "semantic_segmentation_zoo"))
sys.path.append(os.path.join(os.getcwd(), "semantic_segmentation_zoo/core"))
sys.path.append(os.path.join(os.getcwd(), "semantic_segmentation_zoo/core/models"))
sys.path.append(os.path.join(os.getcwd(), "other_models"))
# Disable HDF5 file locking — presumably to let h5py read annotation files on
# shared/network filesystems; verify against the deployment environment.
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
-
-
def set_parser(parser):
    """Register every command-line option used by this evaluation script.

    The function adds arguments and, in several places, calls
    ``parser.parse_args()`` mid-registration so that later defaults can depend
    on earlier option values (dataset name, io mode, ...). The order of the
    ``add_argument`` calls is therefore significant — do not reorder.

    NOTE(review): ``type=bool`` on argparse options does not parse the string
    "False" as False (any non-empty string is truthy), so these flags cannot
    be switched off from the command line as written — confirm before relying
    on them.

    Returns:
        argparse.ArgumentParser: the same parser, fully configured.
    """
    # --- I/O paths and checkpoints -------------------------------------
    parser.add_argument('--input_images_dir', metavar='?', default="BCData/images/test", help='input_images_dir')
    parser.add_argument('--results_save_path', metavar='?', default="pred_results", help='results_save_path')
    parser.add_argument('--train_json', metavar='TRAIN', default=None, help='path to train json')
    parser.add_argument('--checkpoint_save_path', default="results/checkpoints/final_best.ckpt", type=str,
                        help='checkpoint_save_path.')
    parser.add_argument('--test_json', metavar='TEST', default=None, help='path to test json')
    parser.add_argument('--pre_dethead', metavar='PRETRAINED', default=None, type=str,
                        help='path to the pretrained model')
    parser.add_argument('--gpu', metavar='GPU', default="0", type=str, help='GPU id to use.')
    # --- dataset / optimizer / loss ------------------------------------
    parser.add_argument('--dataset_name', metavar='Dataset Name', default="BreCan", type=str, help='Dataset to use.')
    parser.add_argument('--optm', metavar='Optimizer Name', default="adam", type=str, help='Optimizer to use.')
    parser.add_argument('--optm_dethead', metavar='Optimizer Name', default="adam", type=str, help='Optimizer to use.')
    parser.add_argument('--sigma_mul_rate_name', metavar='Optimizer Name', default="sigma_5e-2_", type=str,
                        help='Optimizer to use.')
    parser.add_argument('--original_lr', metavar='Original Learning Rate', default=5.5e-6, type=float,
                        help='Original Learning Rate.')
    parser.add_argument('--add_direct_count_branch', metavar='Whether to add direct count branch', default=False,
                        type=bool, help='Whether to add direct count branch.')
    parser.add_argument('--lamda_loss_cnt', metavar='Lamda Loss Count', default=0.001, type=float,
                        help='Lamda Loss Count.')
    # NOTE(review): add_dcb_cfg is computed here but never used in this function.
    add_dcb_cfg = "True-lamlc={}".format(
        parser.parse_args().lamda_loss_cnt) if parser.parse_args().add_direct_count_branch else "False"
    parser.add_argument('--loss_type', metavar='Loss Type', default="mse", type=str, help='Loss type to use.')
    parser.add_argument('--loss_dethead_type', metavar='Loss Type', default="bce", type=str, help='Loss type to use.')
    parser.add_argument('--front_end_model', metavar='Front_end_model Type', default="resnet50", type=str,
                        help='Front_end_model Type to use.')
    parser.add_argument('--use_pretrained', metavar='Whether to use pretrained model', default=False, type=bool,
                        help='Whether to use pretrained model.')
    # Number of output categories depends on the dataset.
    if "ShanghaiTech" in parser.parse_args().dataset_name:
        parser.add_argument('--category_nums', metavar='Category Nums', default=1, type=int, help='Category Nums.')
    elif "BCCD" in parser.parse_args().dataset_name:
        parser.add_argument('--category_nums', metavar='Category Nums', default=3, type=int, help='Category Nums.')
    elif "BreCan" in parser.parse_args().dataset_name:
        parser.add_argument('--category_nums', metavar='Category Nums', default=2, type=int, help='Category Nums.')
    elif "UW" in parser.parse_args().dataset_name:
        parser.add_argument('--category_nums', metavar='Category Nums', default=1, type=int, help='Category Nums.')
    elif "PSU" in parser.parse_args().dataset_name:
        parser.add_argument('--category_nums', metavar='Category Nums', default=1, type=int, help='Category Nums.')
    parser.add_argument('--attention', default=False, type=bool)
    # Per-dataset sampling of positive/negative center candidates per category.
    if parser.parse_args().dataset_name == "BCCD":
        parser.add_argument('--cate_pos_centers_num_dict', default={"RBC": 20, "WBC": 1, "Platelets": 1})
        parser.add_argument('--cate_neg_centers_num_dict', default={"RBC": 20, "WBC": 1, "Platelets": 1})
    elif parser.parse_args().dataset_name == "BreCan":
        parser.add_argument('--cate_pos_centers_num_dict', default={"DAB": 20, "H": 20})
        parser.add_argument('--cate_neg_centers_num_dict', default={"DAB": 20, "H": 20})
    elif parser.parse_args().dataset_name == "UW":
        parser.add_argument('--cate_pos_centers_num_dict', default={"Nuclei": 20})
        parser.add_argument('--cate_neg_centers_num_dict', default={"Nuclei": 20})
    elif parser.parse_args().dataset_name == "PSU":
        parser.add_argument('--cate_pos_centers_num_dict', default={"Nuclei": 20})
        parser.add_argument('--cate_neg_centers_num_dict', default={"Nuclei": 20})
    # --- ROI / detection-head hyperparameters --------------------------
    parser.add_argument('--center_positions_adjust', default=False, type=bool)
    parser.add_argument('--roi_sigma', default=1., type=float)
    parser.add_argument('--lamda_roi_cls_loss', default=1., type=float)
    parser.add_argument('--lamda_roi_loc_loss', default=100., type=float)
    parser.add_argument('--fixcrop', default=True, type=bool)
    parser.add_argument('--fixsigma', default=True, type=bool)
    parser.add_argument('--use_peak_counts_as_pred_counts', default=True, type=bool)
    parser.add_argument('--use_real_counts_as_gt_counts', default=True, type=bool)
    parser.add_argument('--data_augmentation', default=True, type=bool)
    parser.add_argument('--data_augmentation_dethead', default=False, type=bool)
    # NOTE(review): assert is stripped under `python -O`; this guards a default.
    assert parser.parse_args().data_augmentation_dethead == False
    parser.add_argument('--batch_normalization_csr', default=False, type=bool)
    parser.add_argument('--batch_normalization_det', default=False, type=bool)
    parser.add_argument('--batch_size', default=8, type=int)
    parser.add_argument('--batch_size_dethead', default=1, type=int)
    # Only batch size 1 is supported for the detection head.
    if parser.parse_args().batch_size_dethead != 1:
        raise
    parser.add_argument('--dethead_model', default="DetHeadNew", type=str)
    parser.add_argument('--s_lo', default=False, type=bool)
    if parser.parse_args().s_lo == True:
        parser.add_argument('--lamda_s_lo', default=10, type=float)
    else:
        parser.add_argument('--lamda_s_lo', default=0, type=float)
    # --- main model: io mode decides the output spatial scale ----------
    parser.add_argument('--main_model', default="Final", type=str)
    if parser.parse_args().main_model == "Final":
        parser.add_argument('--io_mode', default="i=o", type=str)
    if parser.parse_args().io_mode == "i=o":
        parser.add_argument('--spatial_scale', default=1, type=int)
    elif parser.parse_args().io_mode == "i>o":
        parser.add_argument('--spatial_scale', default=8, type=int)
    elif parser.parse_args().io_mode == "i<o":
        parser.add_argument('--spatial_scale', default=0.5, type=float)
    # Peak-finding min distance scales inversely with the output resolution.
    if parser.parse_args().spatial_scale == 1:
        parser.add_argument('--find_peaks_min_distance', default=16, type=int)
    elif parser.parse_args().spatial_scale == 8:
        parser.add_argument('--find_peaks_min_distance', default=2, type=int)
    parser.add_argument('--use_one_channel_dm_to_find_localmax', default=True, type=bool)
    parser.add_argument('--lr_adjust', default=False, type=bool)
    parser.add_argument('--dropout', default=True, type=bool)
    parser.add_argument('--dropout_rate', default=0.5, type=float)
    if parser.parse_args().front_end_model == "vgg16":
        parser.add_argument('--l2_regularization', default=False, type=bool)
    elif "res" in parser.parse_args().front_end_model:
        parser.add_argument('--l2_regularization', default=False, type=bool)
    parser.add_argument('--l2_regularization_rate', default=5e-4, type=float)
    parser.add_argument('--relu_type', default="relu", type=str)
    parser.add_argument('--output_type', metavar='Output Processing Type', default="relu", type=str,
                        help='Output processing type to use.')
    # --- input preprocessing / target handling -------------------------
    parser.add_argument('--input_transform', default=True, type=bool)
    parser.add_argument('--tarnorm', default=False, type=bool)
    parser.add_argument('--tar_mul_times', default=50, type=float)
    parser.add_argument('--crop', default=False, type=bool)
    parser.add_argument('--crop_size', default=256, type=int)
    parser.add_argument('--tar_resized', default=False, type=bool)
    # Target resizing must be consistent with the chosen io mode.
    if parser.parse_args().io_mode == "i=o":
        assert parser.parse_args().tar_resized == False
    if parser.parse_args().io_mode == "i<o":
        assert parser.parse_args().tar_resized == True
    parser.add_argument('--dilation', default=True, type=bool)
    parser.add_argument('--change_brightness', default=True, type=bool)
    parser.add_argument('--lr_dethead', default=1e-6, type=float)
    # Detection-head checkpoint/result paths are derived from the main
    # checkpoint path via string substitution.
    parser.add_argument('--checkpoint_save_path_dethead',
                        default=str(parser.parse_args().checkpoint_save_path).replace("./checkpoint/",
                                                                                      "./checkpoint_dethead/").replace(
                            "]/", f"][lr_d={parser.parse_args().lr_dethead}]/"), type=str,
                        help='checkpoint_save_path_dethead.')
    parser.add_argument('--results_save_path_dethead',
                        default=str(parser.parse_args().checkpoint_save_path_dethead).replace("./checkpoint_dethead/",
                                                                                              "./val_while_train_visualized_results_dethead_on_val_dataset/"),
                        type=str, help='results_save_path_dethead.')
    parser.add_argument('--train_results_save_path_dethead',
                        default=str(parser.parse_args().checkpoint_save_path_dethead).replace("./checkpoint_dethead/",
                                                                                              "./val_while_train_visualized_results_dethead_on_train_dataset/"),
                        type=str, help='train_results_save_path_dethead.')
    # --- training schedule ---------------------------------------------
    debug_det_epoch = 666666  # effectively "never stop" epoch count
    parser.add_argument('--pre', '-p', metavar='PRETRAINED', default=None, type=str,
                        help='path to the pretrained model')
    parser.add_argument('--lr', default=parser.parse_args().original_lr, type=float)
    parser.add_argument('--momentum', default=0.95, type=float)
    parser.add_argument('--decay_dethead', default=5 * 1e-4, type=float)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--epochs', default=debug_det_epoch, type=int)
    # LR step milestones per dataset; unsupported datasets fail fast.
    if parser.parse_args().dataset_name == "BreCan":
        parser.add_argument('--steps', default=[-1, 200, 400, 600], type=list)
    elif parser.parse_args().dataset_name == "UW":
        parser.add_argument('--steps', default=[-1, 200, 400, 600], type=list)
    elif parser.parse_args().dataset_name == "PSU":
        parser.add_argument('--steps', default=[-1, 100, 200, 300], type=list)
    else:
        raise
    parser.add_argument('--scales', default=[1, 0.1, 0.1, 0.1], type=list)
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--seed', default=123, type=float)
    parser.add_argument('--print_freq', default=10, type=int)
    parser.add_argument('--print_freq_dethead', default=100, type=int)
    parser.add_argument('--find_peaks_as_det_freq', default=1, type=int)
    parser.add_argument('--save_val_results_while_training', default=True, type=bool)
    # Target image dimensions per dataset.
    if parser.parse_args().dataset_name == "BreCan":
        parser.add_argument('--tar_width', default=640, type=int)
        parser.add_argument('--tar_height', default=640, type=int)
    elif parser.parse_args().dataset_name == "UW":
        parser.add_argument('--tar_width', default=500, type=int)
        parser.add_argument('--tar_height', default=500, type=int)
    elif parser.parse_args().dataset_name == "PSU":
        parser.add_argument('--tar_width', default=612, type=int)
        parser.add_argument('--tar_height', default=452, type=int)
    parser.add_argument('--dethead_epochs', default=1000, type=int)
    parser.add_argument('--use_drop_dethead', default=False, type=bool)
    # --- positive/negative sampling thresholds (pixels) -----------------
    parser.add_argument('--n_sample', default=80, type=int)
    parser.add_argument('--pos_ratio', default=0.5, type=float)
    parser.add_argument('--pos_distance_thresh', default=8, type=float)
    parser.add_argument('--neg_distance_thresh_hi', default=320, type=float)
    parser.add_argument('--neg_distance_thresh_lo', default=16, type=float)
    # Input image dimensions default to the dataset's target dimensions.
    if parser.parse_args().dataset_name == "BCCD":
        parser.add_argument('--input_img_height', default=480, type=int)
        parser.add_argument('--input_img_width', default=640, type=int)
    else:
        parser.add_argument('--input_img_height', default=parser.parse_args().tar_height, type=int)
        parser.add_argument('--input_img_width', default=parser.parse_args().tar_width, type=int)
    parser.add_argument('--resize_factor', default=1, type=int)
    parser.add_argument('--win_size', default=3, type=int)
    # NOTE(review): every supported dataset uses bias_adjust = 0.
    if parser.parse_args().dataset_name == "BreCan":
        parser.add_argument('--bias_adjust', default=0, type=int)
    elif parser.parse_args().dataset_name == "UW":
        parser.add_argument('--bias_adjust', default=0, type=int)
    elif parser.parse_args().dataset_name == "PSU":
        parser.add_argument('--bias_adjust', default=0, type=int)
    else:
        raise
    parser.add_argument('--load_data_from_mem', default=True, type=bool)
    return parser
-
-
# ---------------------------------------------------------------------------
# Script-level configuration (executed at import time).
# ---------------------------------------------------------------------------
val_train = False
# NOTE(review): both branches assign the same value, so val_train currently
# has no effect on use_train_direct_error_to_adjust_val.
if val_train == True:
    use_train_direct_error_to_adjust_val = False
else:
    use_train_direct_error_to_adjust_val = False
# Build the CLI and parse it immediately; `args` is used module-wide below.
parser = argparse.ArgumentParser(description='MindSpore Val')
parser = set_parser(parser)
args = parser.parse_args()
# Local-maximum extraction settings: resize factor and window size depend on
# the chosen peak method and are later overridden by the model's io mode.
local_peak_method = "dingyao"
if local_peak_method == "dingyao":
    args.resize_factor = 8
    args.win_size = 21
elif local_peak_method == "peak_local_max":
    args.resize_factor = 8
    args.win_size = "n"
    args.min_distance = 8
# Per-dataset matching distance threshold (pixels) — NOTE(review): disthresh
# is set here but args.dis_thrd (assigned further below) is what the
# evaluation actually uses; confirm which one is intended.
if args.dataset_name == "BreCan":
    disthresh = 10
elif args.dataset_name == "UW":
    disthresh = 6
elif args.dataset_name == "PSU":
    disthresh = 10
elif args.dataset_name == "BM":
    disthresh = 16
if args.io_mode == "i=o":
    args.resize_factor = 1
elif args.io_mode == "i<o":
    args.resize_factor = 0.5
    args.win_size = 5
args.mean_filter_threshold_rate = 1.0
# Checkpoint-sweep range and density-map thresholds.
load_model_and_val = True
start_epoch = 500
end_epoch = 500
step_epoch = 50
filter_threshold_list = [0.065]
# Flags controlling which intermediate artifacts get written to disk.
save_input = False
save_gt = False
save_visual_det_results = False
save_matlab_det_results = False
dab_thrd = 0.15
h_thrd = 0.12

# Input preprocessing: to-tensor followed by ImageNet mean/std normalization.
totensor = py_vision.ToTensor()
nor = py_vision.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
-
-
def transformer_input_image(image):
    """Convert a PIL image into a normalized tensor.

    Applies the module-level ToTensor transform followed by the module-level
    ImageNet-style Normalize transform.
    """
    tensor_image = totensor(image)
    return nor(tensor_image)
-
-
def validate(val_img_path, model):
    """Run the model on one image and return predicted cell centers.

    The image is normalized, batched, and passed through ``model``; channel 0
    of the predicted density map is treated as the positive (DAB) class and
    channel 1 as the negative (H) class. Peaks of each density map become
    center coordinates in ``[x, y]`` order.

    TODO (translated from the original notes):
      1. Emit results as JSON, e.g. {type: DAB, x: ..., y: ...} per point.
      2. Switch to a polling mode that only processes newly added images.
         SVS-cut sub-images are named ``wsi_name_x_y_h_width.png``; results for
         a sub-image must be mapped back to the original WSI coordinates.

    Returns:
        (pos_cor, neg_cor): two lists of [x, y] peak coordinates.
    """
    val_img = Image.open(val_img_path).convert('RGB')
    val_img = transformer_input_image(val_img)
    val_img = Tensor(np.expand_dims(val_img, axis=0))  # add batch dimension
    pred = model(val_img).asnumpy()[0]

    def _peaks_to_xy(density_map):
        # Local maxima at least half the global peak height; peak_local_max
        # returns (row, col) = (y, x), so swap to [x, y] for downstream use.
        peaks = peak_local_max(density_map, min_distance=81,
                               threshold_abs=np.max(density_map) / 2.,
                               exclude_border=False)
        return [[p[1], p[0]] for p in peaks]

    pos_cor = _peaks_to_xy(pred[0, :, :])
    neg_cor = _peaks_to_xy(pred[1, :, :])
    return pos_cor, neg_cor
-
-
# Category names and name->channel-index mapping per dataset.
if args.dataset_name == "BCCD":
    args.cates = ["RBC", "WBC", "Platelets"]
    args.cate2index = {"RBC": 0, "WBC": 1, "Platelets": 2}
elif args.dataset_name == "BreCan":
    args.cates = ["DAB", "H"]
    args.cate2index = {"DAB": 0, "H": 1}
elif args.dataset_name == "UW":
    args.cates = ["Nuclei"]
    args.cate2index = {"Nuclei": 0}
elif args.dataset_name == "PSU":
    args.cates = ["Nuclei"]
    args.cate2index = {"Nuclei": 0}
elif args.dataset_name == "BM":
    args.cates = ["Nuclei"]
    args.cate2index = {"Nuclei": 0}
# Build the network and load the trained weights from the checkpoint.
model = UCSRNet(classes=2)
param_dict = load_checkpoint(args.checkpoint_save_path)  # change the checkpoint name here to evaluate a different model
load_param_into_net(model, param_dict)
whole_st_time = time.time()
# Metric accumulators keyed bias-adjustment -> train/test -> category.
precisions_dict = {}
recalls_dict = {}
f1_scores_dict = {}
relative_error_dict = {}
relative_error_peaks_dict = {}
adj_list = ["0"]
for adj in adj_list:
    precisions_dict.setdefault(adj, {"train": {}, "test": {}})
    recalls_dict.setdefault(adj, {"train": {}, "test": {}})
    f1_scores_dict.setdefault(adj, {"train": {}, "test": {}})
    relative_error_dict.setdefault(adj, {"train": [], "test": []})
    relative_error_peaks_dict.setdefault(adj, {"train": [], "test": []})
for key in precisions_dict.keys():
    for train_or_test_key in ["train", "test"]:
        for cate in args.cates:
            precisions_dict[key][train_or_test_key].setdefault(str(cate), [])
            recalls_dict[key][train_or_test_key].setdefault(str(cate), [])
            f1_scores_dict[key][train_or_test_key].setdefault(str(cate), [])
# NOTE(review): every dataset/backbone branch below yields the same
# bias_adjust_list = [0]; the per-dataset distinction looks like a leftover
# from earlier parameter sweeps.
if args.dataset_name == "BreCan":
    if "vgg" in args.front_end_model:
        bias_adjust_list = [0]
    elif "res" in args.front_end_model:
        bias_adjust_list = [0]
    elif "alex" in args.front_end_model:
        bias_adjust_list = [0]
elif args.dataset_name == "UW":
    if "vgg" in args.front_end_model:
        bias_adjust_list = [0]
    elif "res" in args.front_end_model:
        bias_adjust_list = [0]
    elif "alex" in args.front_end_model:
        bias_adjust_list = [0]
elif args.dataset_name == "PSU":
    if "vgg" in args.front_end_model:
        bias_adjust_list = [0]
    elif "res" in args.front_end_model:
        bias_adjust_list = [0]
    elif "alex" in args.front_end_model:
        bias_adjust_list = [0]
elif args.dataset_name == "BM":
    if "vgg" in args.front_end_model:
        bias_adjust_list = [0]
    elif "res" in args.front_end_model:
        bias_adjust_list = [0]
    elif "alex" in args.front_end_model:
        bias_adjust_list = [0]
print("Pred Finished. Begin to cal F1.")
-
-
def center_point_evaluation(args, pred_centers, gt_centers, cate, all_tp_dict, all_fp_dict, npos_dict):
    """Greedy center-point matching for one image and one category.

    Each predicted center is matched, in order, to its nearest ground-truth
    center. It is a true positive when that distance is within
    ``args.dis_thrd`` and the ground-truth point has not been claimed by an
    earlier prediction; otherwise it is a false positive.

    Side effects: per-prediction tp/fp flags are appended to
    ``all_tp_dict[str(cate)]`` / ``all_fp_dict[str(cate)]`` and the
    ground-truth count is accumulated in ``npos_dict[str(cate)]`` — these feed
    the cumulative (TMI2016-style) metrics computed later.

    Args:
        pred_centers: (N, 2) array of predicted [x, y] centers.
        gt_centers: (M, 2) array of ground-truth [x, y] centers.

    Returns:
        (precision, recall, F1) for this image/category.
    """
    nd = pred_centers.shape[0]
    nd_gt = gt_centers.shape[0]
    if nd > 0 and nd_gt > 0:
        # Bug fix: the original called .astype(float) and discarded the result
        # (astype returns a new array); assign it so the cast takes effect.
        pred_centers = pred_centers.astype(float)
        gt_centers = gt_centers.astype(float)
        tp = np.zeros(nd)
        fp = np.zeros(nd)
        det_flag = [False] * nd_gt  # ground-truth points already matched
        npos_dict[str(cate)] = npos_dict[str(cate)] + nd_gt
        for d in range(nd):
            # Euclidean distance from prediction d to every ground-truth point.
            distances = ((gt_centers[:, 0] - pred_centers[d, 0]) ** 2 +
                         (gt_centers[:, 1] - pred_centers[d, 1]) ** 2) ** 0.5
            min_dis = np.min(distances)
            jmin = np.argmin(distances)
            if min_dis <= args.dis_thrd and not det_flag[jmin]:
                tp[d] = 1.
                det_flag[jmin] = True
            else:
                # Too far away, or nearest gt already claimed earlier.
                fp[d] = 1.
        all_tp_dict[str(cate)].extend(tp)
        all_fp_dict[str(cate)].extend(fp)
        tp_count = np.sum(tp)
        fp_count = np.sum(fp)
        precision = tp_count / (tp_count + fp_count)  # nd > 0 => denominator > 0
        recall = tp_count / nd_gt
        if precision + recall != 0:
            F1 = 2 * precision * recall / (precision + recall)
        else:
            F1 = 0
    elif nd > 0 and nd_gt == 0:
        # Predictions but no ground truth: everything is a false positive.
        tp = np.zeros(nd, dtype=np.float32)
        fp = np.ones(nd, dtype=np.float32)
        all_tp_dict[str(cate)].extend(tp)
        all_fp_dict[str(cate)].extend(fp)
        precision = 0
        recall = 0
        F1 = 0
    elif nd == 0 and nd_gt > 0:
        # Ground truth but no predictions: all misses.
        npos_dict[str(cate)] = npos_dict[str(cate)] + nd_gt
        precision = 0
        recall = 0
        F1 = 0
    else:
        # Nothing predicted and nothing to find: perfect by convention.
        # (Using `else` also removes the original's unbound-variable risk.)
        precision = 1
        recall = 1
        F1 = 1
    return precision, recall, F1
-
-
def center_point_det_results_initialization(args, result_save_path_root):
    """Create empty per-category accumulators for the detection evaluation.

    ``result_save_path_root`` is accepted for interface compatibility but is
    not used here.

    Returns a 9-tuple of dicts, each keyed by the categories in
    ``args.cate_list``: all_tp / all_fp (per-detection flags), npos
    (ground-truth counts), per-image precision / recall / F1 lists, and the
    peak-count, integration-count and ground-truth-count lists.
    """
    categories = list(args.cate_list)
    all_tp_dict = {cate: [] for cate in categories}
    all_fp_dict = {cate: [] for cate in categories}
    npos_dict = {cate: 0 for cate in categories}
    precision_list_dict = {cate: [] for cate in categories}
    recall_list_dict = {cate: [] for cate in categories}
    F1_list_dict = {cate: [] for cate in categories}
    pred_peak_count_list_dict = {cate: [] for cate in categories}
    pred_integration_count_list_dict = {cate: [] for cate in categories}
    gt_count_list_dict = {cate: [] for cate in categories}
    return (all_tp_dict, all_fp_dict, npos_dict, precision_list_dict,
            recall_list_dict, F1_list_dict, pred_peak_count_list_dict,
            pred_integration_count_list_dict, gt_count_list_dict)
-
-
def center_point_det_results_reports(args, all_tp_dict, all_fp_dict, npos_dict, precision_list_dict, recall_list_dict,
                                     F1_list_dict, pred_peak_count_list_dict, pred_integration_count_list_dict,
                                     gt_count_list_dict):
    """Aggregate per-image detection and counting results into summary metrics.

    For every category in ``args.cate_list`` this computes (a) means of the
    per-image precision/recall/F1 lists, (b) cumulative dataset-level
    precision/recall/F1 in the TMI2016 style from the pooled tp/fp flags, and
    (c) MAE of peak-based and integration-based counts. The Ki-67 index MAE is
    category-independent: it hard-codes the "DAB" (positive) and "H"
    (negative) count lists, so it is computed once and written under every
    category key — TODO(review): confirm both keys always exist when this is
    called with a non-empty cate_list.

    Returns:
        10-tuple of dicts keyed by category, in the original order:
        (mean_F1, mean_precision, mean_recall, mean_F1_TMI2016,
         mean_precision_TMI2016, mean_recall_TMI2016, mae_peak_count,
         mae_peak_ki67, mae_integration_count, mae_integration_ki67).
    """
    mean_F1_dict = {}
    mean_precision_dict = {}
    mean_recall_dict = {}
    mean_F1_TMI2016_dict = {}
    mean_precision_TMI2016_dict = {}
    mean_recall_TMI2016_dict = {}
    mae_peak_count_dict = {}
    mae_integration_count_dict = {}
    mae_peak_ki67_dict = {}
    mae_integration_ki67_dict = {}
    for cate in args.cate_list:
        key = str(cate)
        # (a) per-image means.
        mean_F1_dict[key] = np.mean(F1_list_dict[key])
        mean_precision_dict[key] = np.mean(precision_list_dict[key])
        mean_recall_dict[key] = np.mean(recall_list_dict[key])
        # (b) cumulative precision/recall over all pooled detections.
        all_fp = np.cumsum(np.asarray(all_fp_dict[key]))
        all_tp = np.cumsum(np.asarray(all_tp_dict[key]))
        npos = npos_dict[key]
        try:
            rec = all_tp / float(npos)
            prec = all_tp / np.maximum(all_tp + all_fp, np.finfo(np.float64).eps)
            mean_F1_TMI2016_dict[key] = 2 * rec[-1] * prec[-1] / (rec[-1] + prec[-1])
            mean_precision_TMI2016_dict[key] = prec[-1]
            mean_recall_TMI2016_dict[key] = rec[-1]
        except (IndexError, ZeroDivisionError):
            # No pooled detections (empty arrays) or rec + prec == 0;
            # narrowed from the original bare `except:`.
            mean_F1_TMI2016_dict[key] = 0
            mean_precision_TMI2016_dict[key] = 0
            mean_recall_TMI2016_dict[key] = 0
        # (c) count MAEs.
        peak_count_abs_error_list = []
        integration_count_abs_error_list = []
        for i in range(len(gt_count_list_dict[key])):
            peak_count_abs_error_list.append(
                float(abs(pred_peak_count_list_dict[key][i] - gt_count_list_dict[key][i])))
            integration_count_abs_error_list.append(
                float(abs(pred_integration_count_list_dict[key][i] - gt_count_list_dict[key][i])))
        mae_peak_count_dict[key] = np.mean(peak_count_abs_error_list)
        mae_integration_count_dict[key] = np.mean(integration_count_abs_error_list)
    # Ki-67 index errors — hoisted out of the per-category loop: the original
    # recomputed these identical lists once per category and only the last
    # iteration's result was ever used.
    peak_ki67_abs_error_list = []
    integration_ki67_abs_error_list = []
    if args.cate_list:
        last_key = str(args.cate_list[-1])  # match the original's last-iteration range
        for i in range(len(gt_count_list_dict[last_key])):
            gt_ki67 = float(gt_count_list_dict["DAB"][i]) / float(
                gt_count_list_dict["DAB"][i] + gt_count_list_dict["H"][i])
            try:
                pred_peak_ki67 = float(pred_peak_count_list_dict["DAB"][i]) / float(
                    pred_peak_count_list_dict["DAB"][i] + pred_peak_count_list_dict["H"][i])
            except (ZeroDivisionError, IndexError, KeyError):
                pred_peak_ki67 = 0
            try:
                pred_integration_ki67 = float(pred_integration_count_list_dict["DAB"][i]) / float(
                    pred_integration_count_list_dict["DAB"][i] + pred_integration_count_list_dict["H"][i])
            except (ZeroDivisionError, IndexError, KeyError):
                pred_integration_ki67 = 0
            peak_ki67_abs_error_list.append(abs(pred_peak_ki67 - gt_ki67))
            integration_ki67_abs_error_list.append(abs(pred_integration_ki67 - gt_ki67))
    for cate in args.cate_list:
        # Every category receives the same (category-independent) Ki-67 MAE,
        # exactly as in the original.
        mae_peak_ki67_dict[str(cate)] = np.mean(peak_ki67_abs_error_list)
        mae_integration_ki67_dict[str(cate)] = np.mean(integration_ki67_abs_error_list)
    return mean_F1_dict, mean_precision_dict, mean_recall_dict, mean_F1_TMI2016_dict, mean_precision_TMI2016_dict, mean_recall_TMI2016_dict, mae_peak_count_dict, mae_peak_ki67_dict, mae_integration_count_dict, mae_integration_ki67_dict
-
-
def load_pred_res_from_txt():
    """Load predicted center points per class from the det-result txt files.

    Each line of a result file has the form ``<image_id> <x> <y> ...``;
    coordinates are parsed as ints (via float for robustness).

    Returns:
        {classname: {image_id: [[x, y], ...]}} for the "DAB" and "H" classes.
    """
    cates_pred_centers = {}
    for classname in ["DAB", "H"]:
        detfile = os.path.join(args.results_save_path, "pred_points_txt/{}-det-point-results.txt".format(classname))
        with open(detfile, 'r') as f:
            lines = f.readlines()
        records = [line.strip().split(' ') for line in lines]
        per_image = {}
        for record in records:
            image_id = record[0]
            center = [int(float(value)) for value in record[1:]]
            per_image.setdefault(image_id, []).append(center)
        cates_pred_centers.setdefault(classname, per_image)
    return cates_pred_centers
-
-
# ---------------------------------------------------------------------------
# Evaluation driver: run the model over every test image, match predicted
# centers against the ground-truth coordinate files, then report F1 / MAE.
# ---------------------------------------------------------------------------
args.dis_thrd = 30  # matching distance threshold (pixels) for tp/fp decisions
args.cate_list = ["DAB", "H"]
result_save_path_root = args.results_save_path
peak_count_abs_error_list_dict = {"DAB": [], "H": []}
all_tp_dict, all_fp_dict, npos_dict, precision_list_dict, recall_list_dict, F1_list_dict, pred_peak_count_list_dict, pred_integration_count_list_dict, gt_count_list_dict = center_point_det_results_initialization(
    args, result_save_path_root)
val_list = [os.path.join(args.input_images_dir, item) for item in os.listdir(args.input_images_dir)]
val_list.sort()
test_imgs_path_list = [os.path.basename(item).split(".")[0] for item in val_list]
real_detect = 0  # NOTE(review): assigned but never read afterwards
for img_i, img_path in enumerate(val_list):
    print(f"Processing {img_i + 1}/{len(test_imgs_path_list)}, path is {img_path}")
    pred_dab_coordinates, pred_h_coordinates = validate(img_path, model)
    if img_i == 0:
        # Dump the first image's raw peak coordinates for offline inspection.
        np.save("pos.npy", np.array(pred_dab_coordinates))
        np.save("neg.npy", np.array(pred_h_coordinates))
        print(">>>>>>>>>>>>>npy saved>>>>>>>>>>>>>>>>>>>>>>>>")
        # break
    pred_dab_count = len(pred_dab_coordinates)
    pred_h_count = len(pred_h_coordinates)
    pred_cate_coordinates_list = [pred_dab_coordinates, pred_h_coordinates]
    # Ground-truth coordinate files: index 0 = positive (DAB), 1 = negative (H).
    gt_paths = []
    gt_paths.append(os.path.join('BCData/anotation/test/positive', os.path.basename(img_path).split(".")[0]))
    gt_paths.append(os.path.join('BCData/anotation/test/negative', os.path.basename(img_path).split(".")[0]))
    gt_cates_coordinates = []
    for cate_i, gt_path in enumerate(gt_paths):
        if os.path.exists(gt_path):
            # Bug fix: open read-only and close the handle — the original
            # leaked an open h5py.File and relied on the legacy default mode.
            with h5py.File(gt_path, 'r') as gt_file:
                gt_cates_coordinates.append(list(np.asarray(gt_file['coordinates'])))
        else:
            # Missing annotation file: fall back to a single sentinel point.
            gt_cates_coordinates.append([np.array([-1, -1])])
    gt_dab_count = len(gt_cates_coordinates[0])
    gt_h_count = len(gt_cates_coordinates[1])
    peak_count_abs_error_list_dict["DAB"].append(float(abs(pred_dab_count - gt_dab_count)))
    peak_count_abs_error_list_dict["H"].append(float(abs(pred_h_count - gt_h_count)))
    for cate_i, cate in enumerate(args.cate_list):
        pred_centers = np.array(pred_cate_coordinates_list[cate_i])
        gt_centers = np.array(gt_cates_coordinates[cate_i])
        pred_peak_count_list_dict[args.cate_list[cate_i]].append(len(pred_centers))
        gt_count_list_dict[args.cate_list[cate_i]].append(len(gt_centers))
        # Integration-based counts are not produced by this script.
        pred_integration_count_list_dict[args.cate_list[cate_i]].append(0)
        precision, recall, F1 = center_point_evaluation(args, pred_centers, gt_centers, cate, all_tp_dict, all_fp_dict,
                                                        npos_dict)
        precision_list_dict[str(cate)].append(precision)
        recall_list_dict[str(cate)].append(recall)
        F1_list_dict[str(cate)].append(F1)
# Aggregate everything and print the headline metrics.
mean_F1_dict, mean_precision_dict, mean_recall_dict, mean_F1_TMI2016_dict, mean_precision_TMI2016_dict, mean_recall_TMI2016_dict, mae_peak_count_dict, mae_peak_ki67_dict, mae_integration_count_dict, mae_integration_ki67_dict = center_point_det_results_reports(
    args, all_tp_dict, all_fp_dict, npos_dict, precision_list_dict, recall_list_dict, F1_list_dict,
    pred_peak_count_list_dict, pred_integration_count_list_dict, gt_count_list_dict)
print("F1 of positive tumor cells: ", mean_F1_TMI2016_dict["DAB"])
print("F1 of negative tumor cells: ", mean_F1_TMI2016_dict["H"])
print("AVG F1: ", np.mean([mean_F1_TMI2016_dict["DAB"], mean_F1_TMI2016_dict["H"]]))
# Count MAEs are computed but not printed (kept for parity with the original).
mae_DAB = np.mean(peak_count_abs_error_list_dict["DAB"])
mae_H = np.mean(peak_count_abs_error_list_dict["H"])
print("Finished.")
|