|
'''
This project is a MindSpore reimplementation of U-CSRNet.

Acknowledgments:
The source code of this project borrows heavily from
https://github.com/LeeJunHyun/Image_Segmentation
'''
- import os
- from dataloader import ImageFolder
- import mindspore.dataset as ds
- from model import UCSRNet
- from tools import make_dirs, extract_res, center_point_det_results_initialization, center_point_det_results_reports
- from mindspore.train.serialization import load_param_into_net, load_checkpoint
- from mindspore import context, Tensor
-
-
def main(ckpt_path="save_checkpoint/final_best.ckpt"):
    """Evaluate one UCSRNet checkpoint on the BCData test split.

    Loads the checkpoint into a UCSRNet model, runs batched inference over
    the test images, extracts per-image center-point detection results, and
    prints/appends the mean F1 / precision / recall to a results file.

    Args:
        ckpt_path: Path to the MindSpore checkpoint (.ckpt) to evaluate.
    """
    mode = "test"
    output_root = "results"
    epoch = 36  # searched on validation dataset

    print("ckpt_path: ", ckpt_path)
    res_save_path = os.path.join(output_root, f"extracted_res/{mode}/res_vis")
    make_dirs(res_save_path)
    test_path = "BCData/images/{}/".format(mode)
    image_paths = [os.path.join(test_path, name) for name in os.listdir(test_path)]
    # TODO: Need to search best thrd for pos. and neg. prob map in val dataset, respectively.
    filter_thrd_dict = None
    # filter_thrd_dict = {"positive": 0.15, "negative": 0.12} # searched on validation dataset

    model = UCSRNet(classes=2)
    param_dict = load_checkpoint(ckpt_path)  # change the checkpoint path to switch models
    load_param_into_net(model, param_dict)
    model.set_train(False)  # inference mode (freezes BatchNorm/Dropout behavior)
    print(">>>>>>>>>>>>load weight>>>>>>>>>>>>>>")

    test_image = ImageFolder(root=test_path,
                             mode=mode,
                             augmentation_prob=0,
                             type=2)
    test_loader = ds.GeneratorDataset(test_image, column_names=["input", "target", "paths"], shuffle=False)
    test_loader = test_loader.batch(8)

    length = 0
    all_tp_dict, all_fp_dict, all_scores_dict, npos_dict = center_point_det_results_initialization()

    # Context manager guarantees the results file is closed even if inference raises.
    with open(os.path.join(res_save_path, "res.txt"), 'a') as res_save_f:
        for i, (images, GT, image_paths_index) in enumerate(test_loader):
            if i in [110]:  # NOTE(review): batch 110 is deliberately skipped — presumably a known-bad sample; confirm
                print('skip i image:', i)
                continue

            print("[{}] processing {}/{}".format(mode, i + 1, test_loader.get_dataset_size()))
            outputs = model(images)
            PD = outputs.asnumpy()
            length += images.shape[0]
            extract_res(res_save_path, mode, epoch, image_paths, image_paths_index, PD, filter_thrd_dict,
                        all_tp_dict, all_fp_dict, all_scores_dict, npos_dict)

        mean_F1_dict, mean_precision_dict, mean_recall_dict = center_point_det_results_reports(
            all_tp_dict, all_fp_dict, all_scores_dict, npos_dict)
        print("mean_F1_dict: {}".format(mean_F1_dict))
        print("mean_precision_dict: {}".format(mean_precision_dict))
        print("mean_recall_dict: {}".format(mean_recall_dict))
        res_save_f.write("mean_F1_dict: {}\n".format(mean_F1_dict))
        res_save_f.write("mean_precision_dict: {}\n".format(mean_precision_dict))
        res_save_f.write("mean_recall_dict: {}\n".format(mean_recall_dict))
-
-
def all_test():
    """Evaluate every checkpoint under ``results/checkpoints`` on the test split.

    For each ``*.ckpt`` file found, loads it into a UCSRNet model, runs batched
    inference over the BCData test images, extracts center-point detection
    results, and prints/appends mean F1 / precision / recall per checkpoint.
    """
    ckpt_dirs = "results/checkpoints"
    # f[-1] == "t" keeps files whose name ends in "t" (i.e. ".ckpt"); endswith is the idiomatic spelling
    ckpt_list = [os.path.join(ckpt_dirs, f) for f in os.listdir(ckpt_dirs) if f.endswith("t")]
    for ckpt_path in ckpt_list:
        mode = "test"
        output_root = "results"
        epoch = 36  # searched on validation dataset

        print("ckpt_path: ", ckpt_path)
        res_save_path = os.path.join(ckpt_dirs, f"extracted_res/{mode}/res_vis")
        make_dirs(res_save_path)
        test_path = "BCData/images/{}/".format(mode)
        image_paths = [os.path.join(test_path, name) for name in os.listdir(test_path)]
        # TODO: Need to search best thrd for pos. and neg. prob map in val dataset, respectively.
        filter_thrd_dict = None
        # filter_thrd_dict = {"positive": 0.15, "negative": 0.12} # searched on validation dataset

        # NOTE(review): main() builds UCSRNet(classes=2); confirm classes=1 is intended here.
        model = UCSRNet(classes=1)
        param_dict = load_checkpoint(ckpt_path)  # change the checkpoint path to switch models
        load_param_into_net(model, param_dict)
        # Bug fix: was missing — without it the model stays in training mode
        # (BatchNorm/Dropout active) during inference, unlike main().
        model.set_train(False)
        print(">>>>>>>>>>>>load weight>>>>>>>>>>>>>>")

        test_image = ImageFolder(root=test_path,
                                 mode=mode,
                                 augmentation_prob=0,
                                 type=2)
        test_loader = ds.GeneratorDataset(test_image, column_names=["input", "target", "paths"], shuffle=False)
        test_loader = test_loader.batch(8)

        length = 0
        all_tp_dict, all_fp_dict, all_scores_dict, npos_dict = center_point_det_results_initialization()

        # Context manager guarantees the results file is closed even if inference raises.
        with open(os.path.join(res_save_path, "res.txt"), 'a') as res_save_f:
            for i, (images, GT, image_paths_index) in enumerate(test_loader):
                print("[{}] processing {}/{}".format(mode, i + 1, test_loader.get_dataset_size()))
                outputs = model(images)
                PD = outputs.asnumpy()
                length += images.shape[0]
                extract_res(res_save_path, mode, epoch, image_paths, image_paths_index, PD, filter_thrd_dict,
                            all_tp_dict, all_fp_dict, all_scores_dict, npos_dict)

            mean_F1_dict, mean_precision_dict, mean_recall_dict = center_point_det_results_reports(
                all_tp_dict, all_fp_dict, all_scores_dict, npos_dict)
            print("mean_F1_dict: {}".format(mean_F1_dict))
            print("mean_precision_dict: {}".format(mean_precision_dict))
            print("mean_recall_dict: {}".format(mean_recall_dict))
            res_save_f.write("mean_F1_dict: {}\n".format(mean_F1_dict))
            res_save_f.write("mean_precision_dict: {}\n".format(mean_precision_dict))
            res_save_f.write("mean_recall_dict: {}\n".format(mean_recall_dict))
-
-
# Select the execution backend; swap device_target (and add device_id) for GPU runs, e.g.:
# context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=6)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

if __name__ == '__main__':
    # all_test()  # fix: the function is named all_test(), not test_all()
    main()
|