- import argparse
- import sys
- import os
- import shutil
- import time
- import math
- import logging
- import tensorboardX
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.nn.parallel
- import torch.backends.cudnn as cudnn
- import torch.distributed as dist
- import torch.optim
- import torch.utils.data
- import torch.utils.data.distributed
- import torchvision.transforms as transforms
- import torchvision.datasets as datasets
- import torchvision.models as models
- from datetime import datetime
- from torchvision.models.mobilenet import MobileNetV2
- from models.inceptionresnetv2 import inceptionresnetv2
- from models.adaptive_mobilenet import adaptive_mobilenet_v2
- from models.amc_mbv2 import AMCMobileNetV2
-
- import numpy as np
-
- try:
- from nvidia.dali.plugin.pytorch import DALIClassificationIterator
- from nvidia.dali.pipeline import Pipeline
- import nvidia.dali.ops as ops
- import nvidia.dali.types as types
- except ImportError:
- raise ImportError("Please install DALI from https://www.github.com/NVIDIA/DALI to run this example.")
-
-
- def parse():
- model_names = sorted(name for name in models.__dict__
- if name.islower() and not name.startswith("__")
- and callable(models.__dict__[name]))
-
- parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
- parser.add_argument('data', metavar='DIR', nargs='*', default=['/dev/shm/ImageNet'],  # a list, so args.data[0] is a full path
- help='path(s) to dataset (if one path is provided, it is assumed\n' +
- 'to have subdirectories named "train" and "val"; alternatively,\n' +
- 'train and val paths can be specified directly by providing both paths as arguments)')
- parser.add_argument('--arch', '-a', metavar='ARCH', default='mbv2-adaptive',
- # choices=model_names,
- help='model architecture: ' +
- ' | '.join(model_names) +
- ' (default: mbv2-adaptive)')
- parser.add_argument('-j', '--workers', default=6, type=int, metavar='N',
- help='number of data loading workers (default: 6)')
- parser.add_argument('--epochs', default=250, type=int, metavar='N',
- help='number of total epochs to run')
- parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
- help='manual epoch number (useful on restarts)')
- parser.add_argument('-b', '--batch-size', default=512, type=int,
- metavar='N', help='mini-batch size per process (default: 512)')
- parser.add_argument('--comment', default=None, type=str, metavar='COMMENT',
- help='experiment comment (overrides the automatically generated one)')
- parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
- metavar='LR',
- help='Initial learning rate. Will be scaled by <global batch size>/256: '
- 'args.lr = args.lr*float(args.batch_size*args.world_size)/256. '
- 'A warmup schedule will also be applied over the first 5 epochs.')
- parser.add_argument('--lr-decay', type=str, default='cos',
- help='mode for learning rate decay: {\'old\', \'cos\'}')
- parser.add_argument('--warmup', default=5, type=int, help='number of warmup epochs')
- parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
- help='momentum')
- parser.add_argument('--weight-decay', '--wd', default=4e-5, type=float,
- metavar='W', help='weight decay (default: 4e-5)')
- parser.add_argument('--print-freq', '-p', default=10, type=int,
- metavar='N', help='print frequency (default: 10)')
- parser.add_argument('--resume', default='', type=str, metavar='PATH',
- help='path to latest checkpoint (default: none)')
- parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
- help='evaluate model on validation set')
- parser.add_argument('--pretrained', dest='pretrained', action='store_true',
- help='use pre-trained model')
- parser.add_argument('--kd', action='store_true',
- help='run knowledge-distillation training (InceptionResNetV2 teacher)')
- parser.add_argument('--dali_cpu', action='store_true',
- help='run the CPU-based variant of the DALI pipeline')
- parser.add_argument('--prof', default=-1, type=int,
- help='start the CUDA profiler at iteration N and stop (then exit) 10 iterations later; -1 disables profiling')
- parser.add_argument('--deterministic', action='store_true')
-
- parser.add_argument("--local_rank", default=0, type=int)
- parser.add_argument('--sync_bn', action='store_true', default=False,
- help='enabling apex sync BN.')
-
- parser.add_argument('--opt-level', type=str, default='O1')
- parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
- parser.add_argument('--loss-scale', type=str, default=None)
- # note: argparse's type=bool would turn any non-empty string into True, so use a flag instead
- parser.add_argument('--channels-last', action='store_true', default=False)
- parser.add_argument('-t', '--test', action='store_true',
- help='Launch test mode with preset arguments')
- parser.add_argument('--chcfg', default=None, help='path to the channel config used by mbv2-adaptive')
- args = parser.parse_args()
- return args
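-
- # Example launch (a hedged sketch; the script name "main.py" and the dataset paths are
- # placeholders, not part of this file):
- #
- #     python -m torch.distributed.launch --nproc_per_node=8 main.py \
- #         --arch mbv2-adaptive --kd -b 512 /data/imagenet/train /data/imagenet/val
- #
- # torch.distributed.launch sets WORLD_SIZE/RANK and passes --local_rank to each process,
- # which main() below uses to decide whether to initialize the NCCL process group.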
-
-
- # item() is a recent addition, so this helps with backward compatibility.
- def to_python_float(t):
- if hasattr(t, 'item'):
- return t.item()
- else:
- return t[0]
-
-
- class HybridTrainPipe(Pipeline):
- def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
- shard_id, num_shards, dali_cpu=False):
- super(HybridTrainPipe, self).__init__(batch_size,
- num_threads,
- device_id,
- seed=12 + device_id)
- self.jpegs = self.labels = None
- self.input = ops.FileReader(file_root=data_dir,
- shard_id=shard_id,
- num_shards=num_shards,
- random_shuffle=True,
- pad_last_batch=True)
- # let the user decide which pipeline variant works best for the setup they run
- dali_device = 'cpu' if dali_cpu else 'gpu'
- decoder_device = 'cpu' if dali_cpu else 'mixed'
- # This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet
- # without additional reallocations
- device_memory_padding = 211025920 if decoder_device == 'mixed' else 0
- host_memory_padding = 140544512 if decoder_device == 'mixed' else 0
- self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB,
- device_memory_padding=device_memory_padding,
- host_memory_padding=host_memory_padding)
- self.rrc = ops.RandomResizedCrop(device=dali_device, size=crop, dtype=types.FLOAT,
- interp_type=types.INTERP_TRIANGULAR,
- random_aspect_ratio=[0.75, 1.333333],
- random_area=[0.08, 1.0])
- self.jitter = ops.ColorTwist(device=dali_device)
- # define_graph always feeds this op GPU tensors (images.gpu()), so it must be built on the GPU
- self.cmnp = ops.CropMirrorNormalize(device="gpu",
- dtype=types.FLOAT,
- output_layout=types.NCHW,
- mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
- std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
- self.rng1 = ops.Uniform(range=[0.8, 1.2]) # for brightness, contrast, saturation
- self.rng2 = ops.Uniform(range=[-0.1, 0.1]) # for hue
- self.coin = ops.CoinFlip(probability=0.5)
- logging.info('DALI "{0}" variant'.format(dali_device))
-
- def define_graph(self):
- brightness = self.rng1()
- contrast = self.rng1()
- saturation = self.rng1()
- hue = self.rng2()
- flip = self.coin()
- self.jpegs, self.labels = self.input(name="Reader")
- images = self.decode(self.jpegs)
- images = self.rrc(images)
- images = self.jitter(images, brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
- output = self.cmnp(images.gpu(), mirror=flip)
- return [output, self.labels]
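-
- # Minimal usage sketch (paths are placeholders): build the pipe, then wrap it in the DALI
- # iterator exactly as main() does below:
- #
- #     pipe = HybridTrainPipe(batch_size=64, num_threads=4, device_id=0,
- #                            data_dir='/data/imagenet/train', crop=224,
- #                            shard_id=0, num_shards=1)
- #     pipe.build()
- #     loader = DALIClassificationIterator(pipe, reader_name="Reader", fill_last_batch=False)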
-
-
- class HybridValPipe(Pipeline):
- def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
- size, shard_id, num_shards):
- super(HybridValPipe, self).__init__(batch_size,
- num_threads,
- device_id,
- seed=12 + device_id)
- self.jpegs = self.labels = None
- self.input = ops.FileReader(file_root=data_dir,
- shard_id=shard_id,
- num_shards=num_shards,
- random_shuffle=False,
- pad_last_batch=True)
- self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
- self.res = ops.Resize(device="gpu",
- resize_shorter=size,
- interp_type=types.INTERP_TRIANGULAR)
- self.cmnp = ops.CropMirrorNormalize(device="gpu",
- dtype=types.FLOAT,
- output_layout=types.NCHW,
- crop=(crop, crop),
- mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
- std=[0.229 * 255, 0.224 * 255, 0.225 * 255]
- )
-
- def define_graph(self):
- self.jpegs, self.labels = self.input(name="Reader")
- images = self.decode(self.jpegs)
- images = self.res(images)
- output = self.cmnp(images)
- return [output, self.labels]
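-
- # Note: this reproduces torchvision's standard eval transform: resize the shorter side to
- # `size` (256 here), then center-crop to `crop` x `crop` (224) inside CropMirrorNormalize.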
-
-
- def get_param_group(model):
- param_group_no_wd = []
- names_no_wd = []
- param_group_normal = []
- arch_parameters = []
-
- for name, m in model.named_modules():
- if isinstance(m, (nn.Conv2d, nn.Linear)):  # identical handling, so merge the two branches
- if m.bias is not None:
- param_group_no_wd.append(m.bias)
- names_no_wd.append(name + '.bias')
- elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
- if m.weight is not None:
- param_group_no_wd.append(m.weight)
- names_no_wd.append(name + '.weight')
- if m.bias is not None:
- param_group_no_wd.append(m.bias)
- names_no_wd.append(name + '.bias')
-
- for name, p in model.named_parameters():
- if (name not in names_no_wd) and (name not in arch_parameters):
- param_group_normal.append(p)
-
- return [{'params': param_group_normal}, {'params': param_group_no_wd, 'weight_decay': 0.0}]
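-
- # The two groups plug straight into an optimizer: the first inherits the optimizer-level
- # weight decay, while BN affine parameters and biases are exempted, e.g.
- #
- #     opt = torch.optim.SGD(get_param_group(model), lr=0.1, momentum=0.9, weight_decay=4e-5)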
-
-
- class LabelSmoothCELoss(nn.Module):
- def __init__(self, smooth_ratio, num_classes):
- super(LabelSmoothCELoss, self).__init__()
- self.smooth_ratio = smooth_ratio
- self.val = smooth_ratio / num_classes
- self.log_soft = nn.LogSoftmax(dim=1)
-
- def forward(self, x, label):
- one_hot = torch.zeros_like(x)
- one_hot.fill_(self.val)
- y = label.to(torch.long).view(-1, 1)
- one_hot.scatter_(1, y, 1 - self.smooth_ratio + self.val)
-
- loss = -torch.sum(self.log_soft(x) * one_hot.detach()) / x.size(0)
- return loss
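-
- # Worked example (illustrative numbers): with smooth_ratio=0.1 and num_classes=1000, every
- # class first receives 0.1/1000 = 1e-4, then the true class is overwritten with
- # 1 - 0.1 + 1e-4 = 0.9001, so each target row still sums to 1:
- #
- #     crit = LabelSmoothCELoss(0.1, 1000)
- #     loss = crit(torch.randn(8, 1000), torch.randint(0, 1000, (8,)))  # scalar tensor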
-
-
- def main():
- global best_prec1, args, writer
- best_prec1 = 0
- args = parse()
-
- # test mode, use default args for sanity test
- if args.test:
- args.pretrained = True # For evaluating student during test
- # args.opt_level = None
- # args.epochs = 1
- # args.start_epoch = 0
- # args.arch = 'resnet50'
- # args.batch_size = 64
- # args.data = []
- # args.sync_bn = False
- # args.data.append('/data/imagenet/train-jpeg/')
- # args.data.append('/data/imagenet/val-jpeg/')
- # logging.info("Test mode - no DDP, no apex, RN50, 10 iterations")
-
- if args.comment is None:
- # parenthesize the ternary: without parens, precedence makes the whole value 'no_kd' whenever kd is off
- args.comment = args.arch + "_" + ("kd" if args.kd else "no_kd")
-
- if not len(args.data):
- raise Exception("error: No data set provided")
-
- args.distributed = False
- if 'WORLD_SIZE' in os.environ:
- args.distributed = int(os.environ['WORLD_SIZE']) > 1
-
- logging.debug(args)
-
- # make apex optional
- if args.opt_level is not None or args.distributed or args.sync_bn:
- try:
- global DDP, amp, optimizers, parallel
- from apex.parallel import DistributedDataParallel as DDP
- from apex import amp, optimizers, parallel
- except ImportError:
- raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
-
- logging.info("opt_level = {}".format(args.opt_level))
- logging.info("keep_batchnorm_fp32 = {} {}".format(args.keep_batchnorm_fp32, type(args.keep_batchnorm_fp32)))
- logging.info("loss_scale = {} {}".format(args.loss_scale, type(args.loss_scale)))
-
- logging.info("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
-
- cudnn.benchmark = True
- if args.deterministic:
- cudnn.benchmark = False
- cudnn.deterministic = True
- torch.manual_seed(args.local_rank)
- torch.set_printoptions(precision=10)
-
- args.gpu = 0
- args.world_size = 1
-
- if args.distributed:
- args.gpu = args.local_rank
- torch.cuda.set_device(args.gpu)
- torch.distributed.init_process_group(backend='nccl',
- init_method='env://')
- args.world_size = torch.distributed.get_world_size()
-
- args.total_batch_size = args.world_size * args.batch_size
- assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
-
- # create model
- if args.pretrained:
- logging.info("=> using pre-trained model '{}'".format(args.arch))
- model = models.__dict__[args.arch](pretrained=True)
- else:
- logging.info("=> creating model '{}'".format(args.arch))
- if args.arch == "mbv2-0.75":
- model = MobileNetV2(width_mult=0.75)
- if args.arch == "mbv2-1.5":
- model = MobileNetV2(width_mult=1.50)
- elif args.arch == 'mbv2-adaptive':
- model = adaptive_mobilenet_v2(args.chcfg)
- logging.info("mbv2-adaptive, cfg path: %s" % args.chcfg)
- elif args.arch == 'mbv2-amc':
- model = AMCMobileNetV2()
- elif args.arch == 'irnv2':
- model = inceptionresnetv2(num_classes=1000, pretrained='imagenet') # for fine-tuning teacher
- else:
- model = models.__dict__[args.arch]()
- if args.kd:
- logging.info("=> creating teacher '{}'".format("InceptionResNetV2"))
- teacher = inceptionresnetv2(num_classes=1000, pretrained='imagenet').cuda()
-
- if args.sync_bn:
- logging.info("using apex synced BN")
- model = parallel.convert_syncbn_model(model)
-
- if hasattr(torch, 'channels_last') and hasattr(torch, 'contiguous_format'):
- if args.channels_last:
- memory_format = torch.channels_last
- else:
- memory_format = torch.contiguous_format
- model = model.cuda().to(memory_format=memory_format)
- else:
- model = model.cuda()
-
- # Scale learning rate based on global batch size
- args.lr = args.lr * float(args.batch_size * args.world_size) / 256.
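- # e.g. 8 processes x batch 512 per process: args.lr = 0.1 * 4096 / 256 = 1.6 (before warmup)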
- optimizer = torch.optim.SGD(get_param_group(model), args.lr,
- momentum=args.momentum,
- weight_decay=args.weight_decay,
- nesterov=True)
-
- # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
- # for convenient interoperation with argparse.
- if args.opt_level is not None:
- model, optimizer = amp.initialize(model, optimizer,
- opt_level=args.opt_level,
- keep_batchnorm_fp32=args.keep_batchnorm_fp32,
- loss_scale=args.loss_scale)
- # teacher = amp.initialize(model, None, opt_level="O0")
-
- # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
- # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
- # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
- # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
- if args.distributed:
- # By default, apex.parallel.DistributedDataParallel overlaps communication with
- # computation in the backward pass.
- # model = DDP(model)
- # delay_allreduce delays all communication to the end of the backward pass.
- model = DDP(model, delay_allreduce=True)
- if args.kd:
- teacher = DDP(teacher)
-
- # define plain loss function (criterion) for validation
- criterion = nn.CrossEntropyLoss().cuda()
-
- # Optionally resume from a checkpoint
- if args.resume:
- # Use a local scope to avoid dangling references
- def resume():
- global best_prec1  # update the module-level best, not a function-local shadow
- if os.path.isfile(args.resume):
- logging.info("=> loading checkpoint '{}'".format(args.resume))
- checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.gpu))
- args.start_epoch = checkpoint['epoch']
- best_prec1 = checkpoint['best_prec1']
- model.load_state_dict(checkpoint['state_dict'])
- optimizer.load_state_dict(checkpoint['optimizer'])
- logging.info("=> loaded checkpoint '{}' (epoch {})"
- .format(args.resume, checkpoint['epoch']))
- else:
- logging.info("=> no checkpoint found at '{}'".format(args.resume))
-
- resume()
-
- # Data loading code
- if len(args.data) == 1:
- traindir = os.path.join(args.data[0], 'train')
- valdir = os.path.join(args.data[0], 'val')
- else:
- traindir = args.data[0]
- valdir = args.data[1]
-
- # if args.arch == "inception_v3":
- # raise RuntimeError("Currently, inception_v3 is not supported by this example.")
- # # crop_size = 299
- # # val_size = 320 # I chose this value arbitrarily, we can adjust.
- # else:
- # crop_size = 224
- # val_size = 256
-
- crop_size = 224
- val_size = 256
-
- pipe = HybridTrainPipe(batch_size=args.batch_size,
- num_threads=args.workers,
- device_id=args.local_rank,
- data_dir=traindir,
- crop=crop_size,
- dali_cpu=args.dali_cpu,
- shard_id=args.local_rank,
- num_shards=args.world_size)
- pipe.build()
- train_loader = DALIClassificationIterator(pipe, reader_name="Reader", fill_last_batch=False)
-
- pipe = HybridValPipe(batch_size=args.batch_size,
- num_threads=args.workers,
- device_id=args.local_rank,
- data_dir=valdir,
- crop=crop_size,
- size=val_size,
- shard_id=args.local_rank,
- num_shards=args.world_size)
- pipe.build()
- val_loader = DALIClassificationIterator(pipe, reader_name="Reader", fill_last_batch=False)
-
- if args.test:
- logging.info('*' * 15 + ' Test Start ' + '*' * 15)
- if args.kd:
- # the teacher only exists when --kd was given
- logging.info("Evaluating Teacher...")
- validate(val_loader, teacher, criterion, True)
- val_loader.reset()
- logging.info("Evaluating Student...")
- validate(val_loader, model, criterion, False)
- logging.info('*' * 15 + ' Test Finish ' + '*' * 15)
- return
-
- if args.evaluate:
- validate(val_loader, model, criterion)
- return
-
- total_time = AverageMeter()
- for epoch in range(args.start_epoch, args.epochs):
- # train for one epoch
- if args.kd:
- avg_train_time = train(train_loader, model, optimizer, epoch, True, teacher)
- else:
- avg_train_time = train(train_loader, model, optimizer, epoch)
- total_time.update(avg_train_time)
- if args.test:
- break
-
- # evaluate on validation set
- [prec1, prec5] = validate(val_loader, model, criterion)
-
- # log validation results, remember best prec@1 and save checkpoint
- if args.local_rank == 0:
- train_loader_len = int(math.ceil(train_loader._size / args.batch_size))
- writer.add_scalar("val/acc1", prec1, (epoch + 1) * train_loader_len - 1)
- writer.add_scalar("val/acc5", prec5, (epoch + 1) * train_loader_len - 1)
- is_best = prec1 > best_prec1
- best_prec1 = max(prec1, best_prec1)
- save_checkpoint({
- 'epoch': epoch + 1,
- 'arch': args.arch,
- 'state_dict': model.state_dict(),
- 'best_prec1': best_prec1,
- 'optimizer': optimizer.state_dict(),
- }, is_best, filename=args.comment)
- if epoch == args.epochs - 1:
- logging.info('##Top-1 {0}\n'
- '##Top-5 {1}\n'
- '##Perf {2}'.format(
- prec1,
- prec5,
- args.total_batch_size / total_time.avg))
-
- train_loader.reset()
- val_loader.reset()
- if args.local_rank == 0:
- writer.close()
-
-
- def train(train_loader, model, optimizer, epoch, use_kd=False, teacher=None):
- batch_time = AverageMeter()
- losses = AverageMeter()
- if use_kd:
- losses_cls = AverageMeter()
- losses_kl = AverageMeter()
- top1 = AverageMeter()
- top5 = AverageMeter()
-
- # define loss function (criterion) and optimizer
- # cls_criterion = nn.CrossEntropyLoss().cuda()
- cls_criterion = LabelSmoothCELoss(0.1, 1000).cuda()
- def kl_loss(student_outputs, teacher_outputs, temperature=1):
- # print(student_outputs.shape)
- # print(teacher_outputs.shape)
- loss = nn.KLDivLoss(reduction='batchmean').cuda()(
- F.log_softmax(student_outputs / temperature, dim=1),
- F.softmax(teacher_outputs.detach() / temperature, dim=1)) \
- * (temperature * temperature)
- return loss
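- # The (temperature ** 2) factor compensates for the 1/T softening of the logits so that
- # gradient magnitudes stay comparable across temperatures (Hinton et al., 2015).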
-
- def kd_loss(loss_cls, loss_kl, distill_weight=0.1):
- # combine the already-computed CE and KL terms instead of re-evaluating both criteria
- return loss_cls + distill_weight * loss_kl
-
- kd_criterion = kd_loss
-
- # switch to train mode
- model.train()
- if use_kd:
- teacher.eval()
- end = time.time()
-
- for i, data in enumerate(train_loader):
- stu_input = data[0]["data"]
- tea_input = F.interpolate(stu_input, 299)
-
- target = data[0]["label"].squeeze().cuda().long()
- train_loader_len = int(math.ceil(train_loader._size / args.batch_size))
-
- if args.prof >= 0 and i == args.prof:
- logging.info("Profiling begun at iteration {}".format(i))
- torch.cuda.cudart().cudaProfilerStart()
-
- if args.prof >= 0: torch.cuda.nvtx.range_push("Body of iteration {}".format(i))
-
- adjust_learning_rate(optimizer, epoch, i, train_loader_len, args.warmup)
- if args.test:
- if i > 10:
- break
-
- # compute output
- if args.prof >= 0: torch.cuda.nvtx.range_push("forward")
- stu_output = model(stu_input)
-
- if args.prof >= 0:
- torch.cuda.nvtx.range_pop()
- loss_cls = cls_criterion(stu_output, target)
- if use_kd:
- with torch.no_grad():
- tea_output = teacher(tea_input).detach()
- loss_kl = kl_loss(stu_output, tea_output)
- loss = kd_criterion(loss_cls, loss_kl)
- del tea_output
- else:
- loss = loss_cls
-
- # compute gradient and do SGD step
- optimizer.zero_grad()
-
- if args.prof >= 0: torch.cuda.nvtx.range_push("backward")
- if args.opt_level is not None:
- with amp.scale_loss(loss, optimizer) as scaled_loss:
- scaled_loss.backward()
- else:
- loss.backward()
- if args.prof >= 0: torch.cuda.nvtx.range_pop()
-
- if args.prof >= 0: torch.cuda.nvtx.range_push("optimizer.step()")
- optimizer.step()
- if args.prof >= 0: torch.cuda.nvtx.range_pop()
-
- if i % args.print_freq == 0:
- # Every print_freq iterations, check the loss, accuracy, and speed.
- # For best performance, it doesn't make sense to print these metrics every
- # iteration, since they incur an allreduce and some host<->device syncs.
-
- # Measure accuracy
- prec1, prec5 = accuracy(stu_output.data, target, topk=(1, 5))
-
- # Average loss and accuracy across processes for logging
- if args.distributed:
- reduced_loss = reduce_tensor(loss.data)
- if use_kd:
- reduced_cls_loss = reduce_tensor(loss_cls.data)
- reduced_kl_loss = reduce_tensor(loss_kl.data)
- prec1 = reduce_tensor(prec1)
- prec5 = reduce_tensor(prec5)
- else:
- reduced_loss = loss.data
- if use_kd:
- # mirror the distributed branch so the KD meters below always have values
- reduced_cls_loss = loss_cls.data
- reduced_kl_loss = loss_kl.data
-
- # to_python_float incurs a host<->device sync
- losses.update(to_python_float(reduced_loss), stu_input.size(0))
- top1.update(to_python_float(prec1), stu_input.size(0))
- top5.update(to_python_float(prec5), stu_input.size(0))
-
- if use_kd:
- losses_cls.update(to_python_float(reduced_cls_loss), stu_input.size(0))
- losses_kl.update(to_python_float(reduced_kl_loss), stu_input.size(0))
-
- torch.cuda.synchronize()
- batch_time.update((time.time() - end) / args.print_freq)
- end = time.time()
-
- if args.local_rank == 0:
- logging.info('Epoch: [{0}][{1}/{2}]\t'
- 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
- 'Speed {3:.3f} ({4:.3f})\t'
- 'LR {lr:.5f}\t'
- 'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
- 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
- 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
- epoch, i, train_loader_len,
- args.world_size * args.batch_size / batch_time.val,
- args.world_size * args.batch_size / batch_time.avg,
- batch_time=batch_time,
- lr=optimizer.param_groups[0]['lr'],
- loss=losses, top1=top1, top5=top5))
- step = epoch * train_loader_len + i
- writer.add_scalar("train/speed", args.world_size * args.batch_size / batch_time.val, step)
- writer.add_scalar("train/lr", optimizer.param_groups[0]['lr'], step)
- if use_kd:
- writer.add_scalar("train/loss_cls", losses_cls.val, step)
- writer.add_scalar("train/loss_kl", losses_kl.val, step)
- writer.add_scalar("train/loss_kd", losses.val, step)
- else:
- writer.add_scalar("train/loss_cls", losses.val, step)
- writer.add_scalar("train/acc1", top1.val, step)
- writer.add_scalar("train/acc5", top5.val, step)
-
- # Pop range "Body of iteration {}".format(i)
- if args.prof >= 0: torch.cuda.nvtx.range_pop()
-
- if args.prof >= 0 and i == args.prof + 10:
- logging.info("Profiling ended at iteration {}".format(i))
- torch.cuda.cudart().cudaProfilerStop()
- quit()
-
- return batch_time.avg
-
-
- def validate(val_loader, model, criterion, is_teacher=False):
- batch_time = AverageMeter()
- losses = AverageMeter()
- top1 = AverageMeter()
- top5 = AverageMeter()
-
- # switch to evaluate mode
- model.eval()
-
- end = time.time()
-
- for i, data in enumerate(val_loader):
- input = data[0]["data"]
- if is_teacher:
- # the InceptionResNetV2 teacher was trained at 299x299, so match that resolution
- input = F.interpolate(input, size=299)
- target = data[0]["label"].squeeze().cuda().long()
- val_loader_len = int(math.ceil(val_loader._size / args.batch_size))
-
- # compute output
- with torch.no_grad():
- output = model(input)
- loss = criterion(output, target)
-
- # measure accuracy and record loss
- prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
-
- if args.distributed:
- reduced_loss = reduce_tensor(loss.data)
- prec1 = reduce_tensor(prec1)
- prec5 = reduce_tensor(prec5)
- else:
- reduced_loss = loss.data
-
- losses.update(to_python_float(reduced_loss), input.size(0))
- top1.update(to_python_float(prec1), input.size(0))
- top5.update(to_python_float(prec5), input.size(0))
-
- # measure elapsed time
- batch_time.update(time.time() - end)
- end = time.time()
-
- # TODO: Change timings to mirror train().
- if args.local_rank == 0 and i % args.print_freq == 0:
- logging.info('Test: [{0}/{1}]\t'
- 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
- 'Speed {2:.3f} ({3:.3f})\t'
- 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
- 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
- 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
- i, val_loader_len,
- args.world_size * args.batch_size / batch_time.val,
- args.world_size * args.batch_size / batch_time.avg,
- batch_time=batch_time, loss=losses,
- top1=top1, top5=top5))
-
- logging.info(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))
-
- return [top1.avg, top5.avg]
-
-
- def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
- torch.save(state, filename)
- if is_best:
- shutil.copyfile(filename, filename.split(".")[0] + '_best.pth.tar')
-
-
- class AverageMeter(object):
- """Computes and stores the average and current value"""
-
- def __init__(self):
- self.reset()
-
- def reset(self):
- self.val = 0
- self.avg = 0
- self.sum = 0
- self.count = 0
-
- def update(self, val, n=1):
- self.val = val
- self.sum += val * n
- self.count += n
- self.avg = self.sum / self.count
-
-
- def adjust_learning_rate(optimizer, epoch, step, len_epoch, warmup_epoch=5):
- from math import cos, pi
-
- if args.lr_decay == 'old':
- factor = epoch // 30
-
- if epoch >= 80:
- factor = factor + 1
-
- lr = args.lr * (0.1 ** factor)
-
- elif args.lr_decay == 'cos':
- warmup_iter = warmup_epoch * len_epoch
- current_iter = step + epoch * len_epoch
- max_iter = args.epochs * len_epoch
-
- lr = args.lr * (1 + cos(pi * (current_iter - warmup_iter) / (max_iter - warmup_iter))) / 2
-
- """ Warmup with 1/4 base (e.g. 0.1 to 0.4) """
- if epoch < warmup_epoch:
- base_lr = args.lr / 4.
- k = (args.lr - base_lr) / (warmup_epoch * len_epoch)
- lr = base_lr + k * float(1 + step + epoch * len_epoch)
-
- for param_group in optimizer.param_groups:
- param_group['lr'] = lr
-
-
- def accuracy(output, target, topk=(1,)):
- """Computes the precision@k for the specified values of k"""
- maxk = max(topk)
- batch_size = target.size(0)
-
- _, pred = output.topk(maxk, 1, True, True)
- pred = pred.t()
- correct = pred.eq(target.view(1, -1).expand_as(pred))
-
- res = []
- for k in topk:
- correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
- res.append(correct_k.mul_(100.0 / batch_size))
- return res
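-
- # Tiny worked example: for logits [[0.1, 0.9], [0.8, 0.2]] and targets [1, 1], the first
- # sample is correct and the second is not, so
- #
- #     accuracy(torch.tensor([[0.1, 0.9], [0.8, 0.2]]), torch.tensor([1, 1]), topk=(1,))
- #
- # returns [tensor([50.])].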
-
-
- def reduce_tensor(tensor):
- rt = tensor.clone()
- dist.all_reduce(rt, op=dist.ReduceOp.SUM)
- rt /= args.world_size
- return rt
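-
- # reduce_tensor turns a per-rank value into the cross-rank mean: all_reduce sums the clones
- # and the division by world_size averages them, e.g. losses 0.5 and 0.7 on two ranks both
- # come back as 0.6.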
-
-
- if __name__ == '__main__':
- writer = None
- if int(os.environ["RANK"]) == 0:
- ts = datetime.now().strftime("%Y%m%d-%H%M%S")
- writer = tensorboardX.SummaryWriter(logdir="./logs/{}".format(ts))
- logging.basicConfig(level=logging.DEBUG, filename="./logs/{}/train.log".format(ts),
- filemode="a+", format="%(asctime)-15s %(levelname)-8s %(message)s")
- log_formatter = logging.Formatter("%(asctime)-15s [%(levelname)-8s] %(message)s")
- console_handler = logging.StreamHandler(sys.stdout)
- console_handler.setFormatter(log_formatter)
- logging.getLogger().addHandler(console_handler)
-
- main()