|
'''
This project is a MindSpore reimplementation of U-CSRNet.

Acknowledgments
The source codes of this project borrow heavily from https://github.com/LeeJunHyun/Image_Segmentation
'''
- import os
- from dataloader import ImageFolder
- import mindspore.dataset as ds
- from model import UCSRNet
- from tools import make_dirs, extract_res, center_point_det_results_initialization, center_point_det_results_reports
- from mindspore.train.serialization import load_param_into_net, load_checkpoint
- import numpy as np
-
-
def main():
    """Run center-point detection inference on the BCData test split.

    Loads two single-class U-CSRNet models (one for positive cells, one for
    negative cells), runs both over every test image, concatenates their
    probability maps channel-wise, extracts center-point detections, and
    finally writes mean F1/precision/recall reports to ``res.txt``.

    Side effects: creates the result directory, appends metrics to
    ``results/extracted_res/<mode>/res_vis/res.txt``.
    """
    mode = "test"
    output_root = "results"
    epoch = 36  # best epoch, searched on the validation dataset

    ckpt_path_p = "results/checkpoints_p/final_best.ckpt"
    print("ckpt_path for positive: ", ckpt_path_p)
    ckpt_path_n = "results/checkpoints_n/final_best.ckpt"
    print("ckpt_path for negative: ", ckpt_path_n)

    res_save_path = os.path.join(output_root, f"extracted_res/{mode}/res_vis")
    make_dirs(res_save_path)

    test_path = "BCData/images/{}/".format(mode)
    image_paths = [os.path.join(test_path, name) for name in os.listdir(test_path)]

    # TODO: search the best threshold for the positive and the negative
    # probability map on the validation dataset, respectively.
    filter_thrd_dict = None
    # filter_thrd_dict = {"positive": 0.15, "negative": 0.12}  # searched on validation dataset

    # To evaluate a different checkpoint, change the checkpoint file name.
    model_p = UCSRNet(classes=1)
    load_param_into_net(model_p, load_checkpoint(ckpt_path_p))
    # Inference mode: freeze BatchNorm statistics / disable dropout.
    model_p.set_train(False)

    model_n = UCSRNet(classes=1)
    load_param_into_net(model_n, load_checkpoint(ckpt_path_n))
    model_n.set_train(False)

    print(">>>>>>>>>>>>load weight>>>>>>>>>>>>>>")

    test_image = ImageFolder(root=test_path,
                             mode=mode,
                             augmentation_prob=0,
                             type=0)
    test_loader = ds.GeneratorDataset(test_image, column_names=["input", "target", "paths"], shuffle=False)
    test_loader = test_loader.batch(1)

    all_tp_dict, all_fp_dict, all_scores_dict, npos_dict = center_point_det_results_initialization()

    # NOTE(review): the original script stopped after 3 batches
    # (``if i == 2: break``) — a debug leftover that made the reported
    # metrics reflect only 3 images. Removed so the whole split is scored.
    for i, (images, GT, image_paths_index) in enumerate(test_loader):
        print("[{}] processing {}/{}".format(mode, i + 1, test_loader.get_dataset_size()))
        outputs_p = model_p(images)
        PD_p = outputs_p.asnumpy()

        outputs_n = model_n(images)
        PD_n = outputs_n.asnumpy()

        # Channel-wise concat: channel 0 = positive map, channel 1 = negative map.
        PD = np.concatenate((PD_p, PD_n), axis=1)

        extract_res(res_save_path, mode, epoch, image_paths, image_paths_index, PD, filter_thrd_dict, all_tp_dict,
                    all_fp_dict,
                    all_scores_dict, npos_dict)

    mean_F1_dict, mean_precision_dict, mean_recall_dict = center_point_det_results_reports(all_tp_dict, all_fp_dict,
                                                                                           all_scores_dict, npos_dict)
    print("mean_F1_dict: {}".format(mean_F1_dict))
    print("mean_precision_dict: {}".format(mean_precision_dict))
    print("mean_recall_dict: {}".format(mean_recall_dict))

    # Context manager guarantees the report file is closed even if a
    # write fails; 'a' keeps appending across runs as before.
    with open(os.path.join(res_save_path, "res.txt"), 'a') as res_save_f:
        res_save_f.write("mean_F1_dict: {}\n".format(mean_F1_dict))
        res_save_f.write("mean_precision_dict: {}\n".format(mean_precision_dict))
        res_save_f.write("mean_recall_dict: {}\n".format(mean_recall_dict))
-
-
# Script entry point: run the full test-split evaluation.
if __name__ == '__main__':
    main()
|