|
- '''
- This project is a PyTorch reimplementation of U-CSRNet.
-
- Acknowledgments
- The source code of this project borrows heavily from https://github.com/LeeJunHyun/Image_Segmentation
- '''
-
- import os
- import numpy as np
- from skimage.feature import peak_local_max
- import matplotlib.pyplot as plt
- from PIL import Image, ImageDraw
- import h5py
-
# Set the HDF5_USE_FILE_LOCKING environment variable to 'FALSE' for this process
# (the switch the HDF5 library documents for disabling its file locking).
# NOTE(review): this assignment runs after `import h5py`; confirm that the HDF5
# build in use still reads the variable at this point, otherwise it needs to be
# set before the import.
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
-
-
def make_dirs(root):
    """Create directory ``root`` (including missing parents) if it is absent.

    Args:
        root: Path of the directory to create.

    Raises:
        FileExistsError: If ``root`` already exists but is not a directory.
    """
    # exist_ok=True replaces the separate os.path.exists() check, which could
    # race when several processes try to create the same directory at once.
    os.makedirs(root, exist_ok=True)
-
-
def center_point_det_results_initialization(categories=("positive", "negative")):
    """Create the empty per-category accumulators used by the evaluation.

    Args:
        categories: Iterable of category names. Each name is converted with
            ``str`` and used as a dictionary key. Defaults to the two
            categories used throughout this module.

    Returns:
        A tuple ``(all_tp_dict, all_fp_dict, all_scores_dict, npos_dict)``.
        The first three map every category to its own empty list; the last
        maps every category to ``0``.
    """
    # Materialise once so that one-shot iterables (generators) also work.
    keys = [str(cate) for cate in categories]
    all_tp_dict = {cate: [] for cate in keys}
    all_fp_dict = {cate: [] for cate in keys}
    all_scores_dict = {cate: [] for cate in keys}
    npos_dict = {cate: 0 for cate in keys}
    return all_tp_dict, all_fp_dict, all_scores_dict, npos_dict
-
-
def cal_distance(gt_centers, pred_center):
    """Return the Euclidean distance from one point to every ground-truth point.

    Args:
        gt_centers: Array-like whose rows hold the two coordinates of a
            ground-truth point in columns 0 and 1 (further columns are ignored).
        pred_center: Array-like whose first two entries are the coordinates of
            the predicted point, in the same order as ``gt_centers``.

    Returns:
        1-D ``float64`` array with one distance per row of ``gt_centers``;
        an empty array when ``gt_centers`` has no elements.
    """
    # Work in float64: subtracting integer coordinates directly can wrap
    # around for unsigned dtypes, and plain Python sequences are accepted too.
    gt = np.asarray(gt_centers, dtype=np.float64)
    if gt.size == 0:
        return np.empty(0, dtype=np.float64)
    pred = np.asarray(pred_center, dtype=np.float64)
    return np.hypot(gt[:, 0] - pred[0], gt[:, 1] - pred[1])
-
-
def center_point_evaluation(pred_centers, pred_scores, gt_centers, cate, all_tp_dict, all_fp_dict, all_scores_dict,
                            npos_dict, match_radius=10):
    """Match predicted points to ground-truth points and accumulate the results.

    Predictions are visited in order of decreasing score. A prediction is a
    true positive when its nearest ground-truth point lies within
    ``match_radius`` and has not already been claimed by a higher-scoring
    prediction; otherwise it is a false positive.

    Args:
        pred_centers: Array of predicted points, one per row.
        pred_scores: 1-D array with one score per predicted point.
        gt_centers: Array of ground-truth points, one per row.
        cate: Category name; ``str(cate)`` is the key used in all dictionaries.
        all_tp_dict: Per-category list, extended in place with 1/0 TP flags.
        all_fp_dict: Per-category list, extended in place with 1/0 FP flags.
        all_scores_dict: Per-category list, extended in place with the scores
            (in the same order as the flags).
        npos_dict: Per-category ground-truth counter, increased in place.
        match_radius: Largest distance at which a prediction may match a
            ground-truth point. Defaults to 10, the value previously hard-coded.

    Returns:
        None. All results are written into the passed dictionaries.
    """
    cate = str(cate)
    nd = pred_centers.shape[0]
    nd_gt = gt_centers.shape[0]

    # Every ground-truth point counts as a positive, detected or not.
    if nd_gt > 0:
        npos_dict[cate] = npos_dict[cate] + nd_gt
    if nd == 0:
        return

    if nd_gt == 0:
        # Nothing to match against: every prediction is a false positive.
        all_tp_dict[cate].extend(np.zeros(nd, dtype=np.float32))
        all_fp_dict[cate].extend(np.ones(nd, dtype=np.float32))
        all_scores_dict[cate].extend(pred_scores)
        return

    # astype returns a new array, so the results must be re-bound. Floats keep
    # the score negation and the distance arithmetic free of integer wrap-around.
    pred_centers = pred_centers.astype(float)
    pred_scores = pred_scores.astype(float)
    gt_centers = gt_centers.astype(float)

    order = np.argsort(-pred_scores)
    pred_centers = pred_centers[order, :]
    pred_scores = pred_scores[order]
    all_scores_dict[cate].extend(pred_scores)

    tp = np.zeros(nd)
    fp = np.zeros(nd)
    gt_matched = np.zeros(nd_gt, dtype=bool)
    for d in range(nd):
        distances = cal_distance(gt_centers, pred_centers[d, :])
        nearest = int(np.argmin(distances))
        if distances[nearest] <= match_radius and not gt_matched[nearest]:
            gt_matched[nearest] = True
            tp[d] = 1.
        else:
            # Too far away, or the nearest ground truth is already taken.
            fp[d] = 1.
    all_tp_dict[cate].extend(tp)
    all_fp_dict[cate].extend(fp)
-
-
def center_point_det_results_reports(all_tp_dict, all_fp_dict, all_scores_dict, npos_dict,
                                     categories=("positive", "negative")):
    """Compute overall F1, precision and recall per category.

    Args:
        all_tp_dict: Per-category sequence of 1/0 true-positive flags.
        all_fp_dict: Per-category sequence of 1/0 false-positive flags.
        all_scores_dict: Per-category detection scores. Kept for interface
            compatibility; it is not read, because the reported values use
            only the totals over all detections, which do not depend on the
            score ordering.
        npos_dict: Per-category number of ground-truth points.
        categories: Categories to report. Defaults to the two categories used
            throughout this module.

    Returns:
        A tuple ``(mean_F1_dict, mean_precision_dict, mean_recall_dict)`` of
        dictionaries keyed by category. A metric whose denominator would be
        zero (no detections, no ground truth, or precision + recall == 0) is
        reported as 0 instead of NaN.
    """
    mean_F1_dict = {}
    mean_precision_dict = {}
    mean_recall_dict = {}
    eps = np.finfo(np.float64).eps
    for cate in categories:
        cate = str(cate)
        # Only the last entry of the cumulative sums was ever used, so the
        # plain totals give the same numbers without sorting by score.
        tp_total = float(np.sum(all_tp_dict[cate]))
        fp_total = float(np.sum(all_fp_dict[cate]))
        npos = npos_dict[cate]

        precision = tp_total / max(tp_total + fp_total, eps)
        recall = tp_total / float(npos) if npos > 0 else 0.0
        pr_sum = precision + recall
        f1 = 2 * precision * recall / pr_sum if pr_sum > 0 else 0.0

        mean_F1_dict[cate] = f1
        mean_precision_dict[cate] = precision
        mean_recall_dict[cate] = recall
    return mean_F1_dict, mean_precision_dict, mean_recall_dict
-
-
def _detect_centers(prob_map_2d, threshold):
    """Find local maxima of one probability map; return (x, y) rows and their scores."""
    peaks = peak_local_max(prob_map_2d, min_distance=10, threshold_abs=threshold, exclude_border=False)
    map_max = np.max(prob_map_2d)
    if map_max > 0:
        score_map = prob_map_2d / map_max
    else:
        score_map = np.zeros(prob_map_2d.shape, dtype=np.float32)
    # Scores are read at the peak indices as returned; the centres handed on
    # have their two columns swapped so that column 0 is x and column 1 is y.
    scores = score_map[peaks[:, 0], peaks[:, 1]]
    centers = peaks[:, ::-1]
    return centers, scores


def _load_gt_centers(img_path, img_name, cate):
    """Read the "coordinates" dataset of the annotation file belonging to an image."""
    gt_path = os.path.join(os.path.dirname(img_path.replace("/images/", "/annotations/")), cate,
                           img_name.replace(".png", ".h5"))
    # The context manager closes the file even when the dataset is missing.
    with h5py.File(gt_path, "r") as h5_f:
        return np.asarray(h5_f["coordinates"])


def extract_res(res_save_path, mode, epoch, image_paths, image_paths_index, prob_map, filter_thrd_dict, all_tp_dict,
                all_fp_dict, all_scores_dict, npos_dict, save_vis_res=False):
    """Detect points in predicted probability maps and accumulate evaluation results.

    For every sample ``j`` and each category ("positive" is channel 0,
    "negative" is channel 1) local maxima of ``prob_map[j, channel]`` are
    taken as predicted centres, the ground-truth centres are read from the
    matching ``.h5`` annotation file, and both are passed to
    ``center_point_evaluation``.

    Args:
        res_save_path: Directory that receives the visualisations.
        mode: Sub-directory name used for the saved probability maps
            (``<res_save_path>/<mode>/pred_maps/<category>``).
        epoch: Not used by this function.
        image_paths: Sequence of image file paths. The annotation path is
            derived by replacing "/images/" with "/annotations/", adding the
            category as a directory and replacing ".png" with ".h5".
        image_paths_index: For sample ``j``, ``image_paths_index[j]`` is the
            index into ``image_paths``.
        prob_map: Array indexed as ``prob_map[sample, channel, ...]``.
            NOTE(review): presumably (N, 2, H, W) - confirm against the caller.
        filter_thrd_dict: Mapping from category to an absolute peak threshold.
            When falsy, one tenth of each map's maximum is used instead.
        all_tp_dict: Accumulator updated in place by ``center_point_evaluation``.
        all_fp_dict: Accumulator updated in place by ``center_point_evaluation``.
        all_scores_dict: Accumulator updated in place by ``center_point_evaluation``.
        npos_dict: Accumulator updated in place by ``center_point_evaluation``.
        save_vis_res: When true, save each probability map and the input image
            with predicted (filled, small) and ground-truth (outlined, large)
            circles drawn on it.

    Returns:
        None.
    """
    categories = ("positive", "negative")
    colors = ((0, 0, 255), (0, 255, 0))  # one drawing colour per category
    pred_radius = 2
    gt_radius = 10

    # The output directories do not depend on the sample, so create them once.
    pred_map_dirs = {}
    if save_vis_res:
        for cate in categories:
            pred_map_dirs[cate] = os.path.join(res_save_path, f"{mode}/pred_maps/{cate}")
            make_dirs(pred_map_dirs[cate])

    for j in range(int(prob_map.shape[0])):
        img_path = image_paths[image_paths_index[j]]
        img_name = os.path.basename(img_path)
        img = Image.open(img_path) if save_vis_res else None
        try:
            draw = ImageDraw.Draw(img) if save_vis_res else None
            for cate_i, cate in enumerate(categories):
                tmp_prob_map = prob_map[j, cate_i, ...]
                if save_vis_res:
                    plt.imshow(tmp_prob_map)
                    plt.savefig(os.path.join(pred_map_dirs[cate], f"{img_name}"))
                    plt.close()

                if filter_thrd_dict:
                    threshold = filter_thrd_dict[cate]
                else:
                    threshold = np.max(tmp_prob_map) / 10.
                pred_centers, pred_scores = _detect_centers(tmp_prob_map, threshold)
                gt_centers = _load_gt_centers(img_path, img_name, cate)

                if save_vis_res:
                    color = colors[cate_i]
                    for point in pred_centers:
                        x, y = point[0], point[1]
                        draw.ellipse(xy=(x - pred_radius, y - pred_radius, x + pred_radius, y + pred_radius),
                                     fill=color, outline=color, width=2)
                    for point in gt_centers:
                        x, y = point[0], point[1]
                        draw.ellipse(xy=(x - gt_radius, y - gt_radius, x + gt_radius, y + gt_radius),
                                     fill=None, outline=color, width=4)

                center_point_evaluation(pred_centers, pred_scores, gt_centers, cate, all_tp_dict, all_fp_dict,
                                        all_scores_dict, npos_dict)

            if save_vis_res:
                # The annotated image is shown once, after both categories
                # have been drawn onto it, and then written to disk.
                plt.imshow(img)
                plt.savefig(os.path.join(res_save_path, f"{img_name}"))
                plt.close()
        finally:
            # Release the image file handle even if detection or I/O failed.
            if img is not None:
                img.close()
|