|
- # -*- coding:utf-8 -*-
- import os
- import torch
- from torchvision import transforms
- import numpy as np
- import cv2
- import json
- from safety_train import getModel
-
- import argparse
-
# Command-line options for the safety-testing run.
parser = argparse.ArgumentParser(
    description='PyTorch Model Safety Testing Python Example'
)
# Test data is mounted under /dataset unless overridden.
parser.add_argument(
    '--testdata',
    default="/dataset",
    help='path to test dataset',
)
# Path to the trained weights (a state_dict checkpoint).
parser.add_argument(
    '--ckpt_url',
    default="",
    help='pretrain model path',
)
-
def get_pic_from_dir(dir_path, transform):
    """Load one evaluation dataset from ``dir_path`` as a NumPy array.

    Two on-disk layouts are accepted:

    * ``<dir_path>/inputs.npy`` -- loaded with ``np.load`` and returned
      unchanged (``transform`` is not applied);
    * ``<dir_path>/images/`` -- every file is read with ``cv2.imread``,
      passed through ``transform`` (whose result must provide ``.numpy()``,
      e.g. a ``torch.Tensor``), and the results are stacked with ``np.array``.

    If ``dir_path`` contains exactly one entry, that entry is a directory,
    and it is not named ``images``, the lookup first descends into it
    (datasets are often unpacked into a single wrapper folder).

    Args:
        dir_path: directory to load the dataset from.
        transform: callable applied to each decoded image; only used for
            the ``images/`` layout.

    Returns:
        np.ndarray of samples. For the image layout, samples appear in
        ``os.listdir`` order.

    Raises:
        FileNotFoundError: neither ``inputs.npy`` nor ``images/`` exists
            (a subclass of ``Exception``, so broad handlers still catch it).
        ValueError: ``images/`` is empty, or a file in it cannot be decoded.
    """
    sub_paths = os.listdir(dir_path)
    for entry in sub_paths:
        print(entry)

    # Descend into a lone wrapper directory. Only directories qualify: a
    # folder holding nothing but ``inputs.npy`` must be loaded in place.
    if len(sub_paths) == 1 and sub_paths[0] != "images":
        candidate = os.path.join(dir_path, sub_paths[0])
        if os.path.isdir(candidate):
            dir_path = candidate

    npy_path = os.path.join(dir_path, "inputs.npy")
    if os.path.isfile(npy_path):
        return np.load(npy_path)

    images_dir = os.path.join(dir_path, "images")
    if os.path.isdir(images_dir):
        # NOTE(review): os.listdir order is filesystem-dependent and the
        # predictions are emitted in this order; confirm it matches how the
        # evaluator aligns predictions with labels.
        file_names = os.listdir(images_dir)
        if not file_names:
            raise ValueError("No images found in {}".format(images_dir))
        img_data = []
        for file_name in file_names:
            print(file_name)
            img_path = os.path.join(images_dir, file_name)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread reports unreadable / non-image files by
                # returning None rather than raising.
                raise ValueError("Could not read image {}".format(img_path))
            # NOTE(review): cv2.imread yields BGR channel order; confirm the
            # transform / training pipeline expects BGR rather than RGB.
            sample = transform(img).numpy()
            print(sample.shape)
            img_data.append(sample)
        return np.array(img_data)

    raise FileNotFoundError(
        "The path {} do not has valid dataset {}".format(
            dir_path, os.listdir(dir_path)
        )
    )
-
- if __name__ == '__main__':
- args, unknown = parser.parse_known_args()
- print('cuda is available:{}'.format(torch.cuda.is_available()))
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- # 选择的模型管理中模型文件位置
- modelpath = args.ckpt_url
- print(modelpath)
- # *************************** 个人修改开始 ***************************************
- # 模型网络定义,getModel是从训练模型的定义中获取。
- model = getModel()
-
- # 数据集数据转换,确保与模型训练的一致。
- transform = transforms.Compose(
- [
- transforms.ToPILImage(),
- transforms.Resize([32, 32]),
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
- ]
- )
- # *************************** 个人修改结束 ***************************************
-
- model.load_state_dict(torch.load(modelpath))
- model.eval()
- model.to(device)
- # 处理系统提供的数据基础数据集及对抗数据集
- dataset = args.testdata
- c_dataset = args.testdata
-
- for file_name in os.listdir(args.testdata):
- print(file_name)
- if file_name =="ImageNet1000_100_FGSM":
- c_dataset = args.testdata + "/ImageNet1000_100_FGSM/fgsm_ImageNet1000_100/"
- dataset = args.testdata + "/ImageNet1000_100/ImageNet1000_100/ImageNet1000_100/"
- elif file_name =="CIFAR10_1000_FGSM":
- c_dataset = args.testdata + "/CIFAR10_1000_FGSM/fgsm_cifar10_1000/"
- dataset = args.testdata + "/CIFAR10_1000/cifar10_1000/cifar10_1000/"
-
- print(c_dataset)
- print(dataset)
-
- save_path = "/result"
- save_name = "result.json"
- result_dic = {"model": {}}
- print("inference base dataset ")
- # 推理基础数据集
- origin_data = get_pic_from_dir(dataset, transform)
- with torch.no_grad():
- ret = model(torch.from_numpy(origin_data).float().to(device))
- result_dic["model"]["BDResult"] = ret.tolist()
-
- print("inference c_dataset")
- # 推理攻击样本
- for c_dir in os.listdir(c_dataset):
- if os.path.isdir(c_dataset + "/" + c_dir):
- child_data = get_pic_from_dir(os.path.join(c_dataset, c_dir), transform)
- with torch.no_grad():
- ret = model(torch.from_numpy(child_data).float().to(device))
- if "CDResult" not in result_dic["model"]:
- result_dic["model"]["CDResult"] = {}
- result_dic["model"]["CDResult"][c_dir] = ret.tolist()
-
- #print(result_dic)
-
- # 保存预测结果
- if not os.path.exists(save_path):
- os.makedirs(save_path)
- with open(os.path.join(save_path, save_name), "w") as f:
- json.dump(result_dic, f)
|