|
- import sys
- sys.path.append("..")
- import os
- import gc
- import pandas as pd
- import random
- import argparse
- import time
- import numpy as np
- import torch
- import torch.nn as nn
- from sklearn.metrics import roc_auc_score
- from sklearn.metrics import roc_curve
- from sklearn.metrics import precision_recall_curve
- from tqdm import tqdm
- from READ_pytorch.datasets import MVTecDataset, CLASS_NAMES, OBJECT, TEXTURE
- from READ_pytorch.datasets import resize_transform_basic, randomflip_transform, rotationflip_transform
- from READ_pytorch.utils import EarlyStop, AverageMeter
- from READ_pytorch.utils import visualize_loc_result
- from READ_pytorch.utils import set_logger, time_file_str
- from READ_pytorch.utils import GPUManager
- from READ_pytorch.ad_algorithm import USTAD
- import json
- import platform
- import matplotlib
- import matplotlib.pyplot as plt
- import logging
- logger = logging.getLogger('READ.Train')
-
def parse_args():
    """Parse command-line arguments for USTAD anomaly-detection training.

    Returns:
        argparse.Namespace: parsed options (paths, training hyper-parameters,
        image/normalization settings, and runtime switches).
    """
    parser = argparse.ArgumentParser(description='USTAD anomaly detection training.')
    parser.add_argument('--class_name', type=str, default='bottle')
    parser.add_argument('--data_dir', type=str, default='/home/adc/Datasets/mvtec_anomaly_detection', help='Define the data dir')
    parser.add_argument("--save_dir", type=str, default='../ckpts', help="Define where to save model checkpoints.")
    parser.add_argument("--result_dir", type=str, default='../results', help="Define where to save test results.")
    parser.add_argument('--epochs', type=int, default=100, help='maximum training epochs')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--validation_ratio', type=float, default=0.2)
    parser.add_argument("--n_students", type=int, default=3, help="number of students")
    # BUG FIX: the original used type=list, which splits a CLI string into
    # single characters ("65,33" -> ['6','5',',','3','3']); nargs='+' with
    # type=int accepts "--patch_size 65 33 17" and keeps the same default.
    parser.add_argument("--patch_size", type=int, nargs='+', default=[65, 33, 17],
                        help="patch size list, e.g. --patch_size 65 33 17 or --patch_size 65.")
    parser.add_argument('--img_size', type=int, default=256)
    parser.add_argument("--color", type=str, default='RGB', choices=['RGB', 'BGR', 'GRAY'], help="Define original color of training images")
    parser.add_argument("--mean", nargs='+', type=float, default=[0.5, 0.5, 0.5], help="Define the mean for image normalization.")
    parser.add_argument("--std", nargs='+', type=float, default=[0.5, 0.5, 0.5], help="Define the std for image normalization.")
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate of Adam')
    parser.add_argument("--optimizer", type=str, default='adam', choices=['adam', 'adabelief', 'radam'], help="Define optimizer.")
    parser.add_argument("--gpu_num", type=int, default=1, help="Define how many gpus used to train this model")
    # NOTE(review): the None choice is unreachable from the CLI (argv values
    # are strings); kept for programmatic callers that pass default=None.
    parser.add_argument("--scheduler", type=str, default='step', choices=['step', 'cosine', None], help="Define scheduler.")
    parser.add_argument("--train", action='store_true', help="Whether to train or not.")
    parser.add_argument("--pretrain", type=str, default=None, help="Location of the pretrained weights.")
    parser.add_argument('--experiment', default=None,
                        help='Experiment name (default None).')
    parser.add_argument("--n_viz", type=int, default=30, help="num of viz results.")
    args = parser.parse_args()

    return args
-
-
def main():
    """Entry point: train and/or evaluate USTAD on one or all MVTec classes.

    Per class: build transforms and datasets, split train/val/calibration,
    train (or load pretrained weights and estimate a threshold), score the
    test set, report image- and pixel-level ROCAUC, and save ROC curves,
    localization visualizations, and a cumulative results CSV.
    """
    args = parse_args()
    # Auto-select GPUs and pin them before any CUDA context is created.
    gpu_list = ",".join([str(x) for x in GPUManager().auto_choice(gpu_num=args.gpu_num)])
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_list)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # args.p_crop = 1 if args.crop_size != args.img_resize else 0
    args.input_channel = 1 if args.color == 'GRAY' else 3
    # mean/std may arrive as CLI strings; normalize to floats for transforms.
    args.mean = [float(x) for x in args.mean]
    args.std = [float(x) for x in args.std]
    args.prefix = time_file_str()
    # 'all' runs every MVTec class; otherwise validate the single class name.
    if args.class_name == 'all':
        DATASET = CLASS_NAMES
    else:
        if args.class_name in CLASS_NAMES:
            DATASET = [args.class_name]
        else:
            raise ValueError('dataset not exists')

    test_results = pd.DataFrame()
    # Timestamped result directory, shared by every class of this run.
    result_out = os.path.join(args.result_dir, time.strftime("%m%d_%H%M", time.localtime(time.time())))
    for class_name in DATASET:
        # for class_name in ['transistor', 'wood', 'zipper']:
        test_cache = pd.DataFrame()
        # Init experiment (FIX: 'is None' identity test instead of '== None').
        if args.experiment is None:
            experi_name = '%s_%s' % (
                'USTAD',
                class_name
            )
        else:
            experi_name = args.experiment
        out_dir = os.path.abspath(os.path.join(args.save_dir, experi_name))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        # Augmented transform for training; plain resize for calibration/test.
        flip_transform = randomflip_transform(img_size=(args.img_size, args.img_size),
                                              mean=args.mean,
                                              std=args.std
                                              )
        rotflip_transform = rotationflip_transform(img_size=(args.img_size, args.img_size),
                                                   mean=args.mean,
                                                   std=args.std,
                                                   rotation=180,
                                                   border_mode=0)
        cali_transform = resize_transform_basic(img_size=(args.img_size, args.img_size),
                                                mean=args.mean,
                                                std=args.std
                                                )
        test_transform = resize_transform_basic(img_size=(args.img_size, args.img_size),
                                                mean=args.mean,
                                                std=args.std
                                                )
        total_dataset = MVTecDataset(data_path=args.data_dir, class_name=class_name,
                                     is_train=True, resize=args.img_size, cropsize=args.img_size,
                                     transform=flip_transform,
                                     length=None, img_color=args.color)
        cali_data = MVTecDataset(data_path=args.data_dir, class_name=class_name,
                                 is_train=True, resize=args.img_size, cropsize=args.img_size,
                                 transform=cali_transform,
                                 length=None, img_color=args.color)
        img_nums = len(total_dataset)
        valid_num = int(img_nums * args.validation_ratio)
        train_num = img_nums - valid_num
        # Same fixed seed for both splits so cali_data covers the same image
        # indices as train_data (just without augmentation).
        # NOTE(review): this alignment assumes MVTecDataset enumerates files
        # in a deterministic order — confirm.
        train_data, val_data = torch.utils.data.random_split(total_dataset, [train_num, valid_num], generator=torch.Generator().manual_seed(752))
        cali_data, _ = torch.utils.data.random_split(cali_data, [train_num, valid_num], generator=torch.Generator().manual_seed(752))
        test_data = MVTecDataset(data_path=args.data_dir, class_name=class_name,
                                 is_train=False, resize=args.img_size, cropsize=args.img_size,
                                 transform=test_transform,
                                 length=None, img_color=args.color)
        # Workers/pinning only help with CUDA on Linux; empty kwargs elsewhere.
        loader_kwargs = {'num_workers': 8, 'pin_memory': True} if (torch.cuda.is_available() and platform.system() == 'Linux') else {}
        test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, **loader_kwargs)
        model = USTAD()
        if args.pretrain is not None:
            model.load_weights(args.pretrain)
        if args.train:
            print('Start to train!')
            model.train(train_data, save_path=out_dir, batch_size=args.batch_size,
                        val_data=val_data, scheduler=args.scheduler, cali_data=cali_data, epochs=args.epochs)
        else:
            # Evaluation-only run: estimate the segmentation threshold at a
            # target false-positive rate on the validation split.
            model.est_thres(val_data, expect_fpr=0.005)

        ####### Start to Test #######
        scores = []
        test_imgs = []
        gt_list = []
        gt_mask_list = []
        for (data, label, mask) in tqdm(test_dataloader):
            test_imgs.extend(data.cpu().numpy())
            gt_list.extend(label.cpu().numpy())
            gt_mask_list.extend(mask.cpu().numpy())
            img_score, score = model.predict(data)
            scores.extend(score)

        scores = np.asarray(scores)

        # calculate image-level ROC AUC score: image score = max pixel score.
        img_scores = scores.reshape(scores.shape[0], -1).max(axis=1)
        gt_list = np.asarray(gt_list)
        fpr, tpr, _ = roc_curve(gt_list, img_scores)
        img_roc_auc = roc_auc_score(gt_list, img_scores)
        print('image ROCAUC: %.3f' % (img_roc_auc))

        # calculate per-pixel level ROCAUC and the F1-optimal threshold.
        gt_mask = np.asarray(gt_mask_list)
        precision, recall, thresholds = precision_recall_curve(gt_mask.flatten().astype('uint8'), scores.flatten())
        a = 2 * precision * recall
        b = precision + recall
        # Guard against 0/0 where precision + recall == 0.
        f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
        threshold = thresholds[np.argmax(f1)]
        print('Optimal thres', threshold)
        # NOTE(review): seg_thres is assumed to be set by train()/est_thres()
        # inside USTAD — confirm it exists on the --train path.
        print('Estimate thres', model.seg_thres)
        fpr, tpr, _ = roc_curve(gt_mask.flatten().astype('uint8'), scores.flatten())
        per_pixel_rocauc = roc_auc_score(gt_mask.flatten().astype('uint8'), scores.flatten())
        print('pixel ROCAUC: %.3f' % (per_pixel_rocauc))
        test_cache['class'] = [class_name]
        test_cache['image ROCAUC'] = [img_roc_auc]
        test_cache['pixel ROCAUC'] = [per_pixel_rocauc]
        test_results = pd.concat([test_results, test_cache])
        # FIX: label with the current loop class, not args.class_name (which
        # is the literal string 'all' when running every class).
        # NOTE(review): curves accumulate on one figure across classes, so
        # later PNGs also contain earlier classes' curves — confirm intended.
        plt.plot(fpr, tpr, label='%s ROCAUC: %.3f' % (class_name, per_pixel_rocauc))
        plt.legend(loc="lower right")
        save_dir = result_out + '/' + f'USTAD_{class_name}' + '/' + 'pictures_{:.4f}'.format(
            threshold)
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, class_name + '_roc_curve.png'), dpi=100)

        visualize_loc_result(args, test_imgs, gt_mask_list, scores, threshold, save_dir, class_name, args.n_viz)

        # Rewritten every iteration so partial results survive a crash.
        test_results.to_csv(os.path.join(result_out, 'test_results.csv'), index=False)

    # collect memory
    # NOTE(review): after the loop only the last class's model is freed;
    # per-iteration cleanup would lower peak memory — confirm before moving.
    del model
    gc.collect()
    torch.cuda.empty_cache()

if __name__ == "__main__":
    main()
|