|
- from Utils.Seed_Everything import seed_everything
- import numpy as np
- from Utils.Preprocessing import get_data_set_dual
- from Utils.Seed_Everything import seed_everything
- from Utils.Label_to_Colormap import label_to_colormap
- from Models import HSIViT
-
- import os
- from tqdm import tqdm
- import matplotlib.image as mi
-
- import torch
- import torch.utils.data as data
- from torch.utils.data import DataLoader
-
- from sklearn import metrics
-
- import importlib
- import subprocess
-
- import warnings
- warnings.filterwarnings('ignore')
-
class HSIdataset(data.Dataset):
    """Minimal Dataset wrapper around a sequence of HSI patches.

    Each item is returned as a float32 tensor of shape (1, c, h, w) — a
    single channel-first cube ready for the model. The ``gt``, ``train``
    and ``device`` arguments are stored but not used by this class.
    """

    def __init__(self, data_cubes, gt=None, train=False, device='cuda:0'):
        self.data_cubes = data_cubes
        self.gt = gt
        self.train = train
        self.device = device

    def __getitem__(self, index):
        # cube is assumed to be an (h, w, c) array-like with .copy() — TODO confirm
        cube = self.data_cubes[index]
        tensor = torch.tensor(cube.copy(), dtype=torch.float32)
        # (h, w, c) -> (1, h, w, c) -> (1, c, h, w)
        return tensor.unsqueeze(0).permute(0, 3, 1, 2)

    def __len__(self):
        return len(self.data_cubes)
-
def test_model(data_cubes, test_gt, gt_raw, save_dir, model_name, depth=12, dim=96):
    """Evaluate a saved HSIViT checkpoint on the test set.

    Builds an HSIViT matching the geometry of ``data_cubes``, loads the
    matching weights from ``save_dir/model_name``, predicts a label for every
    patch, writes prediction / ground-truth colormap images, and returns the
    standard accuracy metrics.

    Args:
        data_cubes: sequence of HSI patches, each of shape (h, w, c).
        test_gt: 2-D test label map; 0 marks unlabeled/background pixels.
        gt_raw: full label map, used only to derive the number of classes.
        save_dir: directory holding the checkpoint; result images go here too.
        model_name: checkpoint filename (expected to end in '.pkl').
        depth: transformer encoder depth.
        dim: embedding dimension (num_heads is dim // 16).

    Returns:
        Tuple (oa, aa, kappa, ca): overall accuracy, average accuracy,
        Cohen's kappa, and per-class recall (numpy array).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    h, w, c = data_cubes[0].shape
    n_class = np.max(gt_raw) + 1  # labels are 1..n-1; 0 is background

    dataset = HSIdataset(data_cubes)

    model = HSIViT(img_size=h, patch_size=3, in_chans=1, bands=c, b_patch_size=16,
                   num_class=n_class, embed_dim=dim, depth=depth, num_heads=dim // 16,
                   sep_pos_embed=True, use_learnable_pos_emb=True, trunc_init=True,
                   drop_rate=0., drop_path=0.2).to(device)

    # Partial checkpoint load: copy only parameters whose keys exist in the
    # freshly built model, and report both matched and unmatched keys so any
    # architecture mismatch is immediately visible.
    ignore_keys = []
    load_keys = []
    state_dict = {}
    model_dict = model.state_dict()
    pretrain_model_para = torch.load(os.path.join(save_dir, model_name), map_location=device)
    for key, v in pretrain_model_para.items():
        if key in model_dict:
            state_dict[key] = v
            load_keys.append(key)
        else:
            ignore_keys.append(key)
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)

    print("load_keys:", load_keys)
    print("ignore_keys:", ignore_keys)
    model.eval()

    test_dataload = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=0)

    # Accumulate per-batch argmax predictions. The background logit
    # (column 0) is excluded, so argmax yields indices over classes
    # 1..n_class-1 shifted down by one; the +1 below shifts them back.
    batch_preds = []
    with torch.no_grad():
        for x in tqdm(test_dataload):
            inputs = x.to(device)
            outputs = model(inputs).detach().cpu().numpy()
            batch_preds.append(np.argmax(outputs[:, 1:], axis=1))

    pred = np.concatenate(batch_preds, axis=0) + 1
    pred = pred.reshape(test_gt.shape)
    pred[test_gt == 0] = 0  # mask out unlabeled pixels
    colormap_pre = label_to_colormap(pred)
    colormap_test = label_to_colormap(test_gt)

    # Flatten and keep only labeled pixels; shift labels back to 0-based.
    gt_ = test_gt.reshape(-1)
    gt_label = gt_[gt_ != 0] - 1

    pred_ = pred.reshape(-1)
    pred_label = pred_[gt_ != 0] - 1

    oa = metrics.accuracy_score(gt_label, pred_label)
    aa = np.mean(metrics.recall_score(gt_label, pred_label, average=None))
    kappa = metrics.cohen_kappa_score(gt_label, pred_label)
    ca = metrics.recall_score(gt_label, pred_label, average=None)

    # BUG FIX: these calls previously used the module-level global `save_path`
    # instead of the `save_dir` parameter, so passing a different directory
    # silently saved images to the wrong place (or raised NameError when the
    # function was imported and called outside the script context).
    mi.imsave(os.path.join(save_dir, model_name.replace('.pkl', '_test_' + '.png')), colormap_test)
    mi.imsave(os.path.join(save_dir, model_name.replace('.pkl', '_pre_oa_' + str(np.around(oa * 100, 2)) + '.png')), colormap_pre)

    return oa, aa, kappa, ca
-
-
if __name__ == "__main__":
    # Dataset registry: key -> [data dir, data file stem, gt file stem, results dir].
    model_paths = {
        'PU': ['./Pavia University scene/', 'PaviaU', 'PaviaU_gt', './results/Pavia University scene'],
        'PC': ['./Pavia Centre scene/', 'Pavia', 'Pavia_gt', './results/Pavia Centre scene'],
        'SC': ['./Salinas scene/', 'Salinas', 'Salinas_gt', './results/Salinas scene'],
    }
    model_path = model_paths['SC']

    seeds = [3407, 3408, 3409, 3410, 3411]

    # Experiment knobs.
    patch_size = 3
    labeled_num = 100

    # Resolve file locations from the selected dataset entry.
    data_path = f"{model_path[0]}{model_path[1]}.npy"
    gt_path = f"{model_path[0]}{model_path[2]}.npy"
    save_path = model_path[3]

    model_name = f"HSIMAE_psize_{patch_size}_labeled_num_{labeled_num}.pkl"

    # Encoder size: [12, 144] = HSIMAE-Base, [12, 256] = Large, [12, 512] = Huge.
    enc_paras = [12, 144]

    seed_everything(seeds[0])
    _, _, _, data_cubes, test_gt, gt_raw = get_data_set_dual(
        data_path,
        gt_path,
        patch_size=patch_size,
        num=labeled_num,
        norm=False,
    )

    test_results = []
    test_results_per_class = []

    seed_everything(seeds[1])
    oa, aa, kappa, ca = test_model(
        data_cubes=data_cubes,
        test_gt=test_gt,
        gt_raw=gt_raw,
        save_dir=save_path,
        model_name=model_name,
        depth=enc_paras[0],
        dim=enc_paras[1],
    )
    print("测试集精度:{},{},{},{}".format(oa, aa, kappa, ca))
|