# -*- coding: utf-8 -*-
from __future__ import print_function, division
import sys
# sys.path.append('/home/xujiahong/openI_benchmark/BoT_person_reID/')
import yaml
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
import matplotlib
from data_utils.model_train import ft_net
from data_utils.label_smooth import LSR_loss
from data_utils.triplet import TripletLoss
matplotlib.use('agg')
# import matplotlib.pyplot as plt
# from PIL import Image
import time
import os
from utils.random_erasing import RandomErasing
from utils.model_complexity import compute_model_complexity
from utils.autoaugment import ImageNetPolicy
from utils.util import save_network, get_stream_logger
from config.mainconfig import OUTPUT_RESULT_DIR, CONFIG_PATH
from prepare_dir import prepare_dirs
import scipy.io

version = torch.__version__  # e.g. '1.8.0'; used below to pick the right loss accessor


def train(config_file_path: str, logger):
    # parse the yaml config file
    with open(config_file_path, encoding='utf-8') as f:
        opts = yaml.load(f, Loader=yaml.SafeLoader)
    data_dir = opts['input']['dataset']['data_dir']
    data_name = data_dir.split('/')[-1]
    logger.info("dataset name: %s" % data_name)
    nclass = opts['input']['config']['nclass']
    num_epochs = opts['input']['config']['num_epochs']
    adam = opts['input']['config']['adam']
    name = "trained_" + opts['input']['config']['name']
    batchsize = opts['input']['config']['batchsize']
    inputsize = opts['input']['config']['inputsize']
    w = opts['input']['config']['w']
    h = opts['input']['config']['h']
    stride = opts['input']['config']['stride']
    pool = opts['input']['config']['pool']
    erasing_p = opts['input']['config']['erasing_p']
    lr = opts['input']['config']['lr']
    droprate = opts['input']['config']['droprate']
    warm_epoch = opts['input']['config']['warm_epoch']
    save_path = OUTPUT_RESULT_DIR
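    # A minimal sketch of the expected config layout (the real file ships with
    # the project at CONFIG_PATH; all values below are illustrative only):
    #
    #   input:
    #     dataset:
    #       data_dir: /path/to/dataset
    #     config:
    #       nclass: 751
    #       num_epochs: 60
    #       adam: false
    #       name: resnet50
    #       batchsize: 32
    #       inputsize: 256
    #       w: 256
    #       h: 256
    #       stride: 1
    #       pool: avg
    #       erasing_p: 0.5
    #       lr: 0.02
    #       droprate: 0.5
    #       warm_epoch: 5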
    ############################## transform ###############################
    if h == w:
        transform_train_list = [
            # transforms.RandomRotation(30),
            transforms.Resize((inputsize, inputsize), interpolation=3),
            transforms.Pad(15),
            # transforms.RandomCrop((256, 256)),
            transforms.RandomResizedCrop(size=inputsize, scale=(0.75, 1.0), ratio=(0.75, 1.3333), interpolation=3),  # 3 = Image.BICUBIC
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]
    else:
        transform_train_list = [
            # transforms.RandomRotation(30),
            transforms.Resize((h, w), interpolation=3),
            transforms.Pad(15),
            # transforms.RandomCrop((256, 256)),
            transforms.RandomResizedCrop(size=(h, w), scale=(0.75, 1.0), ratio=(0.75, 1.3333), interpolation=3),  # 3 = Image.BICUBIC
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]
    if erasing_p > 0:
        transform_train_list = transform_train_list + [RandomErasing(probability=erasing_p, mean=[0.0, 0.0, 0.0])]
    transform_train_list_aug = [ImageNetPolicy()] + transform_train_list
    # print(transform_train_list)
    data_transforms = {
        'train': transforms.Compose(transform_train_list),
        'train_aug': transforms.Compose(transform_train_list_aug),
    }
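    # Note: 'train_aug' prepends AutoAugment's ImageNetPolicy to the standard
    # pipeline, but only 'train' is wired into the dataloader below. A hedged
    # sketch of how the augmented pipeline could be enabled instead:
    #
    #   image_datasets['train'] = datasets.ImageFolder(
    #       os.path.join(data_dir, 'bounding_box_train'),
    #       data_transforms['train_aug'])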
    ######################################################################
    # Load the data and the model (pretrained on VehicleNet)
    image_datasets = {}
    image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'bounding_box_train'), data_transforms['train'])
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batchsize,
                                                  shuffle=True, num_workers=8, pin_memory=True)  # 8 workers may work faster
                   for x in ['train']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train']}
    use_gpu = torch.cuda.is_available()
    print('use_gpu', use_gpu)
    # device = torch.device("cuda" if use_gpu else "cpu")
    model = ft_net(class_num=nclass, droprate=droprate, stride=stride, init_model=None, pool=pool, return_f=True)
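    # ft_net is the project's ResNet-based backbone; with return_f=True its
    # forward pass is assumed to return both the classifier logits and the
    # pooled feature embedding (the training loop below unpacks it as
    # outputs[0] = logits, outputs[1] = feature).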
    ##########################
    # Move the model to the GPU *before* constructing the optimizer,
    # so the optimizer holds references to the CUDA parameters.
    if use_gpu:
        model = model.cuda()
    # Train the freshly initialized classifier head at the full learning rate
    # and the pretrained backbone at one tenth of it.
    ignored_params = list(map(id, model.classifier.parameters()))
    base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
    optimizer_ft = optim.SGD([
        {'params': base_params, 'lr': 0.1 * lr},
        {'params': model.classifier.parameters(), 'lr': lr}
    ], weight_decay=5e-4, momentum=0.9, nesterov=True)
    if adam:
        optimizer_ft = optim.Adam(model.parameters(), lr, weight_decay=5e-4)
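    # Quick sanity check (illustrative): each optimizer param group carries
    # its own learning rate, so in the SGD case
    #   for g in optimizer_ft.param_groups:
    #       print(g['lr'])
    # prints 0.1*lr for the backbone group and lr for the classifier. The Adam
    # branch deliberately uses a single group at the base learning rate.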
    # Decay the learning rate by a factor of 0.1 at epoch 35
    # exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=40, gamma=0.1)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones=[35], gamma=0.1)
    # criterion = nn.CrossEntropyLoss()
    criterion = [LSR_loss(), TripletLoss(margin=0.3)]
    if use_gpu:
        criterion = [c.cuda() for c in criterion]
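    # The loss is label-smoothed cross entropy on the logits plus a triplet
    # loss (margin 0.3) on the embeddings, the combination popularized by the
    # "Bag of Tricks" re-ID baseline. LSR_loss and TripletLoss are the
    # project's own implementations; batch-hard mining inside TripletLoss is
    # an assumption, not something visible from this file.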
    # compute model complexity
    if h == w:
        params, FLOPs = compute_model_complexity(model, (1, 3, inputsize, inputsize), verbose=False, only_conv_linear=True)
    else:
        params, FLOPs = compute_model_complexity(model, (1, 3, h, w), verbose=False, only_conv_linear=True)
    logger.info('number of params (M): %.2f' % (params / 1e6))
    logger.info('FLOPs (G): %.2f' % (FLOPs / 1e9))
    ##################### train model ########################################
    y_loss = {'train': []}  # loss history
    y_err = {'train': []}   # error (1 - accuracy) history
    since = time.time()
    warm_up = 0.1  # start from 0.1 * the base learning rate
    warm_iteration = round(dataset_sizes['train'] / batchsize) * warm_epoch  # e.g. the first 5 epochs
    best_model_wts = model.state_dict()
    best_loss = 9999
    best_epoch = 0
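    # Warm-up schedule: the loss is scaled by a factor that grows linearly
    # from 0.1 to 1.0 over the first `warm_epoch` epochs (0.9 / warm_iteration
    # is added per mini-batch). Scaling the loss scales the gradients by the
    # same factor, so this behaves like a learning-rate warm-up for plain SGD
    # (with momentum the equivalence is only approximate).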
    for epoch in range(num_epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, num_epochs))
        # Each epoch has a training (and, if enabled, a validation) phase
        for phase in ['train']:
            if phase == 'train':
                model.train(True)   # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode
            running_loss = 0.0
            running_corrects = 0.0
            # Iterate over data.
            idx = 0
            for data in dataloaders[phase]:
                idx += 1
                # get the inputs; the label is the vehicle ID
                inputs, labels = data
                now_batch_size, c, h, w = inputs.shape  # note: this shadows the config h and w
                if now_batch_size < batchsize:  # skip the last, incomplete batch
                    continue
                if use_gpu:
                    # detach() marks the copies as never requiring gradients
                    inputs = inputs.cuda().detach()
                    labels = labels.cuda().detach()
                # zero the parameter gradients
                optimizer_ft.zero_grad()
                # forward
                if phase == 'val':
                    with torch.no_grad():
                        outputs = model(inputs)
                else:
                    outputs = model(inputs)
                # outputs = [fc_result, feature]
                fc = outputs[0]
                feature = outputs[1]  # = x_s
                _, preds = torch.max(fc.data, 1)
                loss = criterion[0](fc, labels) + criterion[1](feature, labels)[0]
                if epoch < warm_epoch and phase == 'train':
                    warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
                    loss *= warm_up
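                # criterion[1](feature, labels) is indexed with [0] because
                # the project's TripletLoss is assumed to return a tuple whose
                # first element is the scalar loss (the rest presumably being
                # mining statistics); only the scalar takes part in backprop.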
                # backward + optimize only in the training phase
                if phase == 'train':
                    loss.backward()
                    optimizer_ft.step()
                if idx % 50 == 0:
                    logger.info('Iteration:%d loss:%.4f accuracy:%.4f' % (idx, loss.item(), float(torch.sum(preds == labels.data)) / now_batch_size))
                # statistics
                if int(version[0]) > 0 or int(version[2]) > 3:  # new versions like 0.4.0, 0.5.0 and 1.0.0
                    running_loss += loss.item() * now_batch_size
                else:  # old versions like 0.3.0 and 0.3.1
                    running_loss += loss.data[0] * now_batch_size
                running_corrects += float(torch.sum(preds == labels.data))
                del loss, outputs, inputs, preds
            ########### end for ###########
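            # Note: because incomplete final batches are skipped above while
            # dataset_sizes counts every image, epoch_loss and epoch_acc are
            # divided by a slightly larger denominator than the number of
            # samples actually seen; the bias is negligible for large datasets.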
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            logger.info('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0 - epoch_acc)
            # save a checkpoint every epoch
            save_network(model, save_path, name, epoch + 1)
            if phase == 'train':
                exp_lr_scheduler.step()
            time_elapsed = time.time() - since
            logger.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            # keep a deep copy of the weights with the lowest training loss
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                best_epoch = epoch
                best_model_wts = copy.deepcopy(model.state_dict())

    logger.info('Best epoch: {:d} Best Train Loss: {:4f}'.format(best_epoch, best_loss))
    # load the best weights back and save them as the final model
    model.load_state_dict(best_model_wts)
    save_network(model, save_path, name, 'last')
    # save the error and loss history
    loss_name = 'train_loss.mat'
    error_name = 'train_error.mat'
    scipy.io.savemat(os.path.join(save_path, loss_name), y_loss)
    scipy.io.savemat(os.path.join(save_path, error_name), y_err)
    logger.info('total train time: %.2f minutes' % (time_elapsed / 60))
    logger.info('total train epochs: %d epochs' % num_epochs)
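    # The saved histories can be inspected later, e.g. (illustrative):
    #   import scipy.io
    #   hist = scipy.io.loadmat(os.path.join(save_path, 'train_loss.mat'))
    #   print(hist['train'])  # per-epoch training loss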


if __name__ == "__main__":
    prepare_dirs()
    logger = get_stream_logger('TRAIN')
    train(CONFIG_PATH, logger)

Introduction

A beginner's guide to the OpenI community AI collaboration platform. New community members can follow the beginner training course in this project, which moves from walkthroughs of individual features to a full hands-on project, taking you step by step through the platform's code, dataset, Cloudbrain, and task features. Free compute quota is available as well.
