|
- import os
- from Vargfacenet import WholeNet, VarGFaceNet
- from mindspore import context
- import mindspore.dataset as ds
- from PIL import Image
- import mindspore.dataset.transforms.c_transforms as C2
- import mindspore.dataset.vision.c_transforms as C
- import mindspore.common.dtype as mstype
- from mindspore import Tensor
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- import numpy as np
- from mindspore import nn
- import argparse
- import scipy.io
- import moxing as mox
- import sys
-
# Run mode: 'debug' (ModelArts notebook) or 'train' (submitted training job).
# environment = 'debug'
environment = 'train'
workroot = '/home/ma-user/work' if environment == 'debug' else '/home/work/user-job-dir'
print('current work mode:' + environment + ', workroot:' + workroot)

parser = argparse.ArgumentParser(description='Face validation')
parser.add_argument('--data_url', help='path to training/inference dataset folder', default=workroot + '/data/')
parser.add_argument('--train_url', help='model folder to save/load', default=workroot + '/model/')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'CPU'])
# NOTE(review): '--chechpoint_name' is misspelled but kept as-is — the flag
# name and the args_opt attribute are part of the script's CLI interface.
parser.add_argument('--chechpoint_name', type=str, default="VargFaceNet_webface_Multistep--70_958.ckpt")
args_opt = parser.parse_args()
-
-
class LFW():
    """Map-style source of LFW verification pairs.

    Each item yields the left/right images of a pair plus their horizontal
    flips (test-time augmentation) and the pair's flag (+1 genuine,
    -1 impostor).
    """

    def __init__(self, nameLs, nameRs, flags):
        # Parallel lists: nameLs[i] / nameRs[i] / flags[i] describe pair i.
        self.nameLs = nameLs
        self.nameRs = nameRs
        self.flags = flags

    def __getitem__(self, index):
        left = Image.open(self.nameLs[index]).convert('RGB')
        right = Image.open(self.nameRs[index]).convert('RGB')
        return (left, left.transpose(Image.FLIP_LEFT_RIGHT),
                right, right.transpose(Image.FLIP_LEFT_RIGHT),
                self.flags[index])

    def __len__(self):
        return len(self.nameLs)
-
-
def create_dataset(nameLs, nameRs, flags):
    """Wrap LFW pairs in a MindSpore dataset with eval-time preprocessing.

    Every image column is resized to 112x96, normalized with per-channel
    mean/std (expressed in 0-255 pixel scale) and converted HWC->CHW; the
    flag column is cast to int32.  Batches of 40 with drop_remainder, one
    epoch.

    Args:
        nameLs, nameRs: file paths of the left/right images of each pair.
        flags: +1 for genuine pairs, -1 for impostor pairs.

    Returns:
        A batched mindspore.dataset pipeline with columns
        [imageL, flip_imageL, imageR, flip_imageR, flag].
    """
    dataset = LFW(nameLs, nameRs, flags)
    # shuffle=False: pairs must stay in file order so the extracted features
    # line up with the folds/flags produced by parseList.
    lfw_ds = ds.GeneratorDataset(dataset, ["imageL", "flip_imageL", "imageR", "flip_imageR", "flag"],
                                 shuffle=False)
    mean = [0.4914 * 255, 0.4822 * 255, 0.4465 * 255]
    std = [0.2023 * 255, 0.1994 * 255, 0.2010 * 255]
    # LFW.__getitem__ already decodes via PIL, so no C.Decode() op is needed
    # (the original created one but never applied it).
    transform_img = [
        C.Resize((112, 96)),
        C.Normalize(mean=mean, std=std),
        C.HWC2CHW(),
    ]
    transform_label = [C2.TypeCast(mstype.int32)]

    for column in ('imageL', 'imageR', 'flip_imageL', 'flip_imageR'):
        lfw_ds = lfw_ds.map(input_columns=column, operations=transform_img)
    lfw_ds = lfw_ds.map(input_columns='flag', operations=transform_label)

    lfw_ds = lfw_ds.project(columns=["imageL", "flip_imageL", "imageR", "flip_imageR", "flag"])

    lfw_ds = lfw_ds.batch(batch_size=40, drop_remainder=True)
    lfw_ds = lfw_ds.repeat(1)

    return lfw_ds
-
-
def parseList(root):
    """Parse LFW-style ``pairs.txt`` under *root* into pair metadata.

    Lines with 3 tab-separated fields (name, idx1, idx2) are genuine pairs
    (flag +1); lines with 4 fields (name1, idx1, name2, idx2) are impostor
    pairs (flag -1).  The fold index assumes the standard 6000-pair file
    (600 pairs per fold).

    Args:
        root: directory containing ``pairs.txt`` and the ``lfw-112X96``
            image folder.

    Returns:
        [nameLs, nameRs, folds, flags] — parallel lists of left image
        paths, right image paths, fold ids, and +1/-1 flags.
    """
    with open(os.path.join(root, 'pairs.txt')) as f:
        pairs = f.read().splitlines()[1:]  # first line is the pair-count header
    folder_name = 'lfw-112X96'
    nameLs = []
    nameRs = []
    flags = []
    folds = []
    for i, p in enumerate(pairs):
        p = p.split('\t')
        if len(p) == 3:
            nameL = os.path.join(root, folder_name, p[0], p[0] + '_' + '{:04}.jpg'.format(int(p[1])))
            nameR = os.path.join(root, folder_name, p[0], p[0] + '_' + '{:04}.jpg'.format(int(p[2])))
            flag = 1
        elif len(p) == 4:
            nameL = os.path.join(root, folder_name, p[0], p[0] + '_' + '{:04}.jpg'.format(int(p[1])))
            nameR = os.path.join(root, folder_name, p[2], p[2] + '_' + '{:04}.jpg'.format(int(p[3])))
            flag = -1
        else:
            # Malformed line: the original fell through and re-appended the
            # previous iteration's values (or raised NameError on the first
            # line) — skip such lines explicitly instead.
            continue
        nameLs.append(nameL)
        nameRs.append(nameR)
        flags.append(flag)
        folds.append(i // 600)
    return [nameLs, nameRs, folds, flags]
-
-
def getAccuracy(scores, flags, threshold):
    """Fraction of pairs classified correctly at *threshold*.

    A genuine pair (flag +1) is correct when its score is above the
    threshold; an impostor pair (flag -1) when it is below.
    """
    correct_genuine = (scores[flags == 1] > threshold).sum()
    correct_impostor = (scores[flags == -1] < threshold).sum()
    return 1.0 * (correct_genuine + correct_impostor) / len(scores)
-
-
def getThreshold(scores, flags, thrNum):
    """Grid-search the accuracy-maximizing threshold in [-1, 1].

    Evaluates 2*thrNum + 1 evenly spaced candidate thresholds and returns
    the mean of all candidates that achieve the maximum accuracy (ties are
    averaged).
    """
    thresholds = np.arange(-thrNum, thrNum + 1) * 1.0 / thrNum
    accuracys = np.array([getAccuracy(scores, flags, t) for t in thresholds])
    best_mask = accuracys == accuracys.max()
    return np.mean(thresholds[best_mask])
-
-
def evaluation_10_fold(feature_save_dir='./result/result1.mat'):
    """Compute 10-fold LFW verification accuracy from saved features.

    Loads the .mat file written by getFeatureFromMindspore and, for each
    fold i, normalizes features with statistics from the other nine folds,
    picks the best cosine-similarity threshold on those folds, and scores
    fold i with it.

    Args:
        feature_save_dir: path of the .mat file with keys
            'fl', 'fr' (feature matrices), 'fold', 'flag'.

    Returns:
        np.ndarray of shape (10,) with per-fold accuracies.
    """
    result = scipy.io.loadmat(feature_save_dir)
    # Loop-invariant: read/squeeze once instead of on every fold as before.
    fold = result['fold']               # shape (1, N): fold id per pair
    flags = np.squeeze(result['flag'])  # shape (N,): +1 genuine, -1 impostor
    ACCs = np.zeros(10)
    for i in range(10):
        # Re-read features each fold: they are mutated (mean-subtracted and
        # normalized) below with fold-dependent statistics.
        featureLs = result['fl']
        featureRs = result['fr']
        valFold = fold != i
        testFold = fold == i

        # Subtract the validation-fold mean, then L2-normalize each row.
        mu = np.mean(np.concatenate((featureLs[valFold[0], :], featureRs[valFold[0], :]), 0), 0)
        featureLs = featureLs - mu
        featureRs = featureRs - mu
        featureLs = featureLs / np.expand_dims(np.sqrt(np.sum(np.power(featureLs, 2), 1)), 1)
        featureRs = featureRs / np.expand_dims(np.sqrt(np.sum(np.power(featureRs, 2), 1)), 1)

        # Cosine similarity per pair; tune the threshold on the validation
        # folds and evaluate on the held-out fold.
        scores = np.sum(np.multiply(featureLs, featureRs), 1)
        threshold = getThreshold(scores[valFold[0]], flags[valFold[0]], 10000)
        ACCs[i] = getAccuracy(scores[testFold[0]], flags[testFold[0]], threshold)

    return ACCs
-
-
def getFeatureFromMindspore(lfw_dir, feature_save_dir, weight_file_path='0.ckpt'):
    """Extract LFW pair features with a checkpointed network and save a .mat.

    For every pair, the left/right embeddings are each concatenated with the
    embedding of the horizontally flipped image (test-time augmentation).

    Args:
        lfw_dir: dataset root containing pairs.txt and the image folder.
        feature_save_dir: output .mat path (keys 'fl', 'fr', 'fold', 'flag').
        weight_file_path: checkpoint file to load into the network.
    """
    net = WholeNet(train_phase=False, num_class=10575)
    param_dict = load_checkpoint(weight_file_path)
    net.init_parameters_data()
    load_param_into_net(net, param_dict)
    net.set_train(False)

    nameLs, nameRs, folds, flags = parseList(lfw_dir)
    lfw_dataset = create_dataset(nameLs, nameRs, flags)

    # Collect per-batch features in lists and concatenate once at the end;
    # the previous np.concatenate-per-batch accumulation was quadratic in
    # the number of batches.
    featureL_batches = []
    featureR_batches = []
    for IL, FIL, IR, FIR, label in lfw_dataset.create_tuple_iterator():
        featureL = np.concatenate((net(IL).asnumpy(), net(FIL).asnumpy()), 1)
        featureR = np.concatenate((net(IR).asnumpy(), net(FIR).asnumpy()), 1)
        featureL_batches.append(featureL)
        featureR_batches.append(featureR)

    featureLs = np.concatenate(featureL_batches, 0)
    featureRs = np.concatenate(featureR_batches, 0)

    result = {'fl': featureLs, 'fr': featureRs, 'fold': folds, 'flag': flags}
    scipy.io.savemat(feature_save_dir, result)
-
-
if __name__ == '__main__':

    print('args:')
    print(args_opt)
    data_dir = workroot + '/data'

    if not os.path.exists(data_dir):
        os.mkdir(data_dir)

    obs_train_url = args_opt.train_url
    train_dir = workroot + '/model/'
    if not os.path.exists(train_dir):
        os.mkdir(train_dir)

    # Pull the evaluation dataset from OBS when running as a training job
    # (best-effort: a failed copy is logged, not fatal).
    if environment == 'train':
        obs_data_url = args_opt.data_url
        try:
            mox.file.copy_parallel(obs_data_url, data_dir)
            print("Successfully Download {} to {}".format(obs_data_url,
                                                          data_dir))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                obs_data_url, data_dir) + str(e))

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    weight_file = os.path.join(sys.path[0], 'checkpoint', args_opt.chechpoint_name)
    result_dir = os.path.join(sys.path[0], 'result')
    if not os.path.exists(result_dir):
        os.mkdir(result_dir)
    # Save features inside result/ — the original wrote a stray 'rusult.mat'
    # at sys.path[0] and never used the result directory it had just created.
    feature_save_dir = os.path.join(result_dir, 'result.mat')

    getFeatureFromMindspore(data_dir, feature_save_dir, weight_file_path=weight_file)

    ACCs = evaluation_10_fold(feature_save_dir=feature_save_dir)

    # Write per-fold and average accuracy; with-block guarantees the file
    # is closed even if a write fails.
    txt_path = os.path.join(train_dir, 'result.txt')
    with open(txt_path, 'w') as txt_file:
        for i in range(len(ACCs)):
            txt_str = '{} {:.2f}'.format(i + 1, ACCs[i] * 100)
            print(txt_str)
            txt_file.write(txt_str)
            txt_file.write('\n')
        print("------")
        txt_file.write("--------------")
        txt_file.write('\n')
        txt_file.write('AVE {:.2f}'.format(np.mean(ACCs) * 100))
        print('AVE {:.2f}'.format(np.mean(ACCs) * 100))

    # Push the result file back to OBS (best-effort, as above).
    if environment == 'train':
        try:
            mox.file.copy_parallel(train_dir, obs_train_url)
            print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
        except Exception as e:
            print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
|