You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

197 lines
7.1 kB

  1. # -*- coding: utf-8 -*-
  2. from __future__ import print_function, division
  3. # import sys
  4. # sys.path.append('/home/xujiahong/openI_benchmark/vechicle_reID_VechicleNet/')
  5. import time
  6. import yaml
  7. import pickle
  8. import torch
  9. import torch.nn as nn
  10. import numpy as np
  11. from torchvision import datasets,transforms
  12. import os
  13. import scipy.io
  14. from tqdm import tqdm
  15. from data_utils.model_train import ft_net
  16. from utils.util import get_stream_logger
  17. from config.mainconfig import OUTPUT_RESULT_DIR, CONFIG_PATH
  18. def fliplr(img):
  19. '''flip horizontal'''
  20. inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W
  21. img_flip = img.index_select(3,inv_idx)
  22. return img_flip
  23. def extract_feature(model, dataloaders, flip):
  24. features = torch.FloatTensor()
  25. count = 0
  26. for _, data in enumerate(tqdm(dataloaders),0):
  27. img, _ = data
  28. n, c, h, w = img.size()
  29. count += n
  30. input_img = img.cuda()
  31. ff = model(input_img)
  32. if flip:
  33. img = fliplr(img)
  34. input_img = img.cuda()
  35. outputs_flip = model(input_img)
  36. ff += outputs_flip
  37. fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
  38. ff = ff.div(fnorm.expand_as(ff))
  39. #print(ff.shape)
  40. features = torch.cat((features,ff.data.cpu().float()), 0)
  41. #features = torch.cat((features,ff.data.float()), 0)
  42. return features
  43. def get_id(img_path):
  44. '''
  45. xjh:
  46. example of the name of the img: 0769_c013_00074310_0
  47. 0769 is the vehicleID, 013 is the cameraID, 00074310 is the frameID
  48. '''
  49. camera_id = []
  50. labels = []
  51. for path, _ in img_path:
  52. #filename = path.split('/')[-1]
  53. filename = os.path.basename(path) #get the name of images
  54. # Test Gallery Image
  55. if not 'c' in filename:
  56. labels.append(9999999)
  57. camera_id.append(9999999)
  58. else:
  59. #label = filename[0:4]
  60. label = filename[0:5] #for benchmark_person
  61. camera = filename.split('c')[1]
  62. if label[0:2]=='-1':
  63. labels.append(-1)
  64. else:
  65. labels.append(int(label))
  66. #camera_id.append(int(camera[0:3]))
  67. camera_id.append(int(camera[0:2]))#for benchmark_person
  68. #print(camera[0:3])
  69. return camera_id, labels
  70. def test(config_file_path:str, logger):
  71. #read config files
  72. with open(config_file_path, encoding='utf-8') as f:
  73. opts = yaml.load(f, Loader=yaml.SafeLoader)
  74. data_dir = opts['input']['dataset']['data_dir']
  75. name = "trained_" + opts['input']['config']['name']
  76. trained_model_name = name + "_last.pth"
  77. save_path = OUTPUT_RESULT_DIR
  78. nclass = opts['input']['config']['nclass']
  79. stride = opts['input']['config']['stride']
  80. pool = opts['input']['config']['pool']
  81. droprate = opts['input']['config']['droprate']
  82. inputsize= opts['input']['config']['inputsize']
  83. w = opts['input']['config']['w']
  84. h = opts['input']['config']['h']
  85. batchsize = opts['input']['config']['batchsize']
  86. flip = opts['test']['flip_test']
  87. trained_model_path = os.path.join(save_path, trained_model_name)
  88. ##############################load model#################################################
  89. ###self-train
  90. model = ft_net(class_num = nclass, droprate = droprate, stride=stride, init_model=None, pool = pool, return_f=False)
  91. try:
  92. model.load_state_dict(torch.load(trained_model_path))
  93. except:
  94. model = torch.nn.DataParallel(model)
  95. model.load_state_dict(torch.load(trained_model_path))
  96. model = model.module
  97. model.classifier.classifier = nn.Sequential() #model ends with feature extractor(output len is 512)
  98. # print(model)
  99. ##############################load dataset###############################################
  100. #transforms for input image h==w==299, inputsize==256
  101. if h == w:
  102. data_transforms = transforms.Compose([
  103. transforms.Resize( ( round(inputsize*1.1), round(inputsize*1.1)), interpolation=3),
  104. transforms.ToTensor(),
  105. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  106. ])
  107. else:
  108. data_transforms = transforms.Compose( [
  109. transforms.Resize((round(h*1.1), round(w*1.1)), interpolation=3), #Image.BICUBIC
  110. transforms.ToTensor(),
  111. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  112. ])
  113. image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['bounding_box_test','query']}
  114. dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batchsize,
  115. shuffle=False, num_workers=8) for x in ['bounding_box_test','query']}
  116. #############################check GPU###################################################
  117. use_gpu = torch.cuda.is_available()
  118. #############################extract features############################################
  119. # Change to test mode
  120. model = model.eval()
  121. if use_gpu:
  122. model = model.cuda()
  123. gallery_path = image_datasets['bounding_box_test'].imgs
  124. query_path = image_datasets['query'].imgs
  125. gallery_cam,gallery_label = get_id(gallery_path)
  126. query_cam,query_label = get_id(query_path)
  127. gallery_label = np.asarray(gallery_label)
  128. query_label = np.asarray(query_label)
  129. gallery_cam = np.asarray(gallery_cam)
  130. query_cam = np.asarray(query_cam)
  131. print('Gallery Size: %d'%len(gallery_label))
  132. print('Query Size: %d'%len(query_label))
  133. # Extract feature
  134. since = time.time()
  135. with torch.no_grad():
  136. gallery_feature = extract_feature(model, dataloaders['bounding_box_test'], flip)
  137. query_feature = extract_feature(model, dataloaders['query'], flip)
  138. process_time = time.time() - since
  139. logger.info('total forward time: %.2f minutes'%(process_time/60))
  140. dist = 1-torch.mm(query_feature, torch.transpose(gallery_feature, 0, 1))
  141. # Save to Matlab for check
  142. extracted_feature = {'gallery_feature': gallery_feature.numpy(), 'gallery_label':gallery_label, 'gallery_cam':gallery_cam, \
  143. 'query_feature': query_feature.numpy(), 'query_label':query_label, 'query_cam':query_cam}
  144. result_name = os.path.join(save_path, name+'_feature.mat')
  145. scipy.io.savemat(result_name, extracted_feature)
  146. return_dict = {}
  147. return_dict['dist'] = dist.numpy()
  148. return_dict['feature_example'] = query_feature[0].numpy()
  149. return_dict['gallery_label'] = gallery_label
  150. return_dict['gallery_cam'] = gallery_cam
  151. return_dict['query_label'] = query_label
  152. return_dict['query_cam'] = query_cam
  153. pickle.dump(return_dict, open(OUTPUT_RESULT_DIR+'test_result.pkl', 'wb'), protocol=4)
  154. return
  155. # eval_result = evaluator(result, logger)
  156. # full_table = display_eval_result(dict = eval_result)
  157. # logger.info(full_table)
# Script entry point: build a stream logger and run feature extraction/testing
# with the project-wide config path.
if __name__=="__main__":
    logger = get_stream_logger('TEST')
    test(CONFIG_PATH, logger)

简介

启智社区AI协同平台小白操作指南~~~~ 社区新童鞋们可以参考本项目下的小白训练课程,从单个功能讲解到项目实战,手把手带你了解和上手平台的代码、数据集、云脑、任务等各功能,好用到根本停不下来~!!更有免费的算力哦~!!

Python Markdown

贡献者 (4)