From 012639736376da6541d1c6240da06c2a17ab0d33 Mon Sep 17 00:00:00 2001 From: xfey Date: Sat, 18 Jun 2022 13:47:36 +0800 Subject: [PATCH 1/7] bugs fixing --- .../fbresnet_imagenet/fbresnet.py | 2 +- xnas/datasets/imagenet.py | 47 +++++++------- xnas/datasets/loader.py | 64 +++++++++++-------- 3 files changed, 62 insertions(+), 51 deletions(-) diff --git a/xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet.py b/xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet.py index 96ea1b4..60f6c7d 100644 --- a/xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet.py +++ b/xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet.py @@ -7,7 +7,7 @@ import math import torch.utils.model_zoo as model_zoo import torch -WEIGHT_PATH = 'teacher_model/fbresnet_imagenet/fbresnet152.pth' +WEIGHT_PATH = 'xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet152.pth' __all__ = ['FBResNet', #'fbresnet18', 'fbresnet34', 'fbresnet50', 'fbresnet101', diff --git a/xnas/datasets/imagenet.py b/xnas/datasets/imagenet.py index 3a6bf73..f565f91 100644 --- a/xnas/datasets/imagenet.py +++ b/xnas/datasets/imagenet.py @@ -28,8 +28,9 @@ class ImageFolder(): def __init__( self, datapath, - split, - batch_size=None, + batch_size, + split=None, + use_val=False, dataset_name='imagenet', _rgb_normalized_mean=None, _rgb_normalized_std=None, @@ -41,9 +42,9 @@ class ImageFolder(): datapath = './data/imagenet/' if not datapath else datapath assert os.path.exists(datapath), "Data path '{}' not found".format(datapath) - self.use_val = cfg.LOADER.USE_VAL - self._data_path, self._split, self.dataset_name = datapath, split, dataset_name - self._rgb_normalized_mean, self._rgb_normalized_std = _rgb_normalized_mean, _rgb_normalized_std + self.use_val = use_val + self.data_path, self.split, self.dataset_name = datapath, split, dataset_name + self.rgb_normalized_mean, self.rgb_normalized_std = _rgb_normalized_mean, _rgb_normalized_std self.num_workers = cfg.LOADER.NUM_WORKERS if num_workers is None else num_workers self.pin_memory = cfg.LOADER.PIN_MEMORY if pin_memory is None else pin_memory self.shuffle = shuffle @@ -62,7 +63,7 @@ class ImageFolder(): else: self.transforms = transforms if not self.use_val: - assert len(self.transforms) == len(self._split), "Length of split and transforms should be equal" + assert len(self.transforms) == len(self.split), "Length of split and transforms should be equal" else: assert len(self.transforms) == 2 @@ -80,15 +81,15 @@ class ImageFolder(): def _construct_imdb(self): # Images are stored per class in subdirs (format: n) if not self.use_val: - split_files = os.listdir(self._data_path) + split_files = os.listdir(self.data_path) else: - split_files = os.listdir(os.path.join(self._data_path, "train")) + split_files = os.listdir(os.path.join(self.data_path, "train")) if self.dataset_name == "imagenet": # imagenet format folder names self._class_ids = sorted( f for f in split_files if re.match(r"^n[0-9]+$", f)) - self._rgb_normalized_mean = [0.485, 0.456, 0.406] - self._rgb_normalized_std = [0.229, 0.224, 0.225] + self.rgb_normalized_mean = [0.485, 0.456, 0.406] + self.rgb_normalized_std = [0.229, 0.224, 0.225] elif self.dataset_name == 'custom': self._class_ids = sorted( f for f in split_files if not f[0] == '.') @@ -102,7 +103,7 @@ class ImageFolder(): self._imdb = [] for class_id in self._class_ids: cont_id = self._class_id_cont_id[class_id] - train_im_dir = os.path.join(self._data_path, class_id) + train_im_dir = os.path.join(self.data_path, class_id) for im_name in os.listdir(train_im_dir): im_path = os.path.join(train_im_dir, im_name) if is_image_file(im_path): @@ -112,8 +113,8 @@ class ImageFolder(): else: self._train_imdb = [] self._val_imdb = [] - train_path = os.path.join(self._data_path, "train") - val_path = os.path.join(self._data_path, "val") + train_path = os.path.join(self.data_path, "train") + val_path = os.path.join(self.data_path, "val") for class_id in self._class_ids: cont_id = self._class_id_cont_id[class_id] train_im_dir = os.path.join(train_path, class_id) @@ -138,15 +139,15 @@ class ImageFolder(): data_loaders = [] pre_partition = 0. pre_index = 0 - for i, _split in enumerate(self._split): + for i, _split in enumerate(self.split): _current_partition = pre_partition + _split _current_index = int(len(self._imdb) * _current_partition) _current_indices = indices[pre_index: _current_index] assert not len(_current_indices) == 0, "The length of indices is zero!" dataset = ImageList_torch([self._imdb[j] for j in _current_indices], self.msrc, # add support for multisize_random_crop - _rgb_normalized_mean=self._rgb_normalized_mean, - _rgb_normalized_std=self._rgb_normalized_std, + _rgb_normalized_mean=self.rgb_normalized_mean, + _rgb_normalized_std=self.rgb_normalized_std, **self.transforms[i]) sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None loader = self.loader(dataset, @@ -163,8 +164,8 @@ class ImageFolder(): train_dataset = ImageList_torch( self._train_imdb, self.msrc, - _rgb_normalized_mean=self._rgb_normalized_mean, - _rgb_normalized_std=self._rgb_normalized_std, + _rgb_normalized_mean=self.rgb_normalized_mean, + _rgb_normalized_std=self.rgb_normalized_std, **self.transforms[0] ) sampler = DistributedSampler(train_dataset) if cfg.NUM_GPUS > 1 else None @@ -177,8 +178,8 @@ class ImageFolder(): val_dataset = ImageList_torch( self._val_imdb, self.msrc, - _rgb_normalized_mean=self._rgb_normalized_mean, - _rgb_normalized_std=self._rgb_normalized_std, + _rgb_normalized_mean=self.rgb_normalized_mean, + _rgb_normalized_std=self.rgb_normalized_std, **self.transforms[1] ) sampler = DistributedSampler(val_dataset) if cfg.NUM_GPUS > 1 else None @@ -209,8 +210,8 @@ class ImageList_torch(torch.utils.data.Dataset): random_flip=False): self._imdb = _list self.msrc = msrc - self._bgr_normalized_mean = _rgb_normalized_mean[::-1] - self._bgr_normalized_std = _rgb_normalized_std[::-1] + self._rgb_normalized_mean = _rgb_normalized_mean + self._rgb_normalized_std = _rgb_normalized_std self.crop = crop self.crop_size = crop_size self.min_crop = min_crop @@ -234,7 +235,7 @@ class ImageList_torch(torch.utils.data.Dataset): if self.random_flip: transforms.append(torch_transforms.RandomHorizontalFlip()) transforms.append(torch_transforms.ToTensor()) - transforms.append(torch_transforms.Normalize(mean=self._bgr_normalized_mean, std=self._bgr_normalized_std)) + transforms.append(torch_transforms.Normalize(mean=self._rgb_normalized_mean, std=self._rgb_normalized_std)) self.transform = torch_transforms.Compose(transforms) def __getitem__(self, index): diff --git a/xnas/datasets/loader.py b/xnas/datasets/loader.py index 28a80a4..d0266f3 100644 --- a/xnas/datasets/loader.py +++ b/xnas/datasets/loader.py @@ -31,16 +31,16 @@ def construct_loader( split = cfg.LOADER.SPLIT name = cfg.LOADER.DATASET - batch_size = cfg.LOADER.BATCH_SIZE + batch_size = cfg.LOADER.BATCH_SIZE datapath = cfg.LOADER.DATAPATH assert (name in SUPPORTED_DATASETS) or (name in IMAGEFOLDER_FORMAT), "dataset not supported." # expand batch_size to support different number during training & validating if isinstance(batch_size, int): - batch_size = [batch_size, batch_size] + batch_size = [batch_size] * len(split) elif batch_size is None: - batch_size = [256, 256] + batch_size = [256] * len(split) assert len(batch_size) == len(split), "lengths of batch_size and split should be same." # check if randomresized crop is used only in ImageFolder type datasets @@ -52,9 +52,10 @@ def construct_loader( train_data, _ = get_data(name, datapath, cutout_length, use_classes=use_classes, transforms=transforms) return split_dataloader(train_data, batch_size, split) elif name in IMAGEFOLDER_FORMAT: + assert cfg.LOADER.USE_VAL is False, "do not get normal dataloaders." return ImageFolder( # using path of training data of ImageNet as `datapath` - datapath, split, batch_size=batch_size, - transforms=transforms, + datapath, batch_size=batch_size, split=split, + use_val=False, transforms=transforms, ).generate_data_loader() else: print("dataset not supported.") @@ -137,7 +138,6 @@ def get_normal_dataloader( name=None, train_batch=None, cutout_length=0, - download=True, use_classes=None, transforms=None, ): @@ -147,28 +147,38 @@ def get_normal_dataloader( root=cfg.LOADER.DATAPATH test_batch=cfg.TEST.BATCH_SIZE - # get normal dataloaders with train&test subsets. - train_data, test_data = get_data(name, root, cutout_length, download, use_classes, transforms) - - # if loader.batch_size is a list for [train, val_1, ...], the first value will be used. - if isinstance(train_batch, list): - train_batch = train_batch[0] + assert (name in SUPPORTED_DATASETS) or (name in IMAGEFOLDER_FORMAT), "dataset not supported." + assert isinstance(train_batch, int), "normal dataloader using single training batch-size, not list." + # check if randomresized crop is used only in ImageFolder type datasets + if len(cfg.SEARCH.MULTI_SIZES): + assert name in IMAGEFOLDER_FORMAT, "RandomResizedCrop can only be used in ImageFolder currently." + + if name in SUPPORTED_DATASETS: + # get normal dataloaders with train&test subsets. + train_data, test_data = get_data(name, root, cutout_length, use_classes=use_classes, transforms=transforms) - train_loader = data.DataLoader( - dataset=train_data, - batch_size=train_batch, - shuffle=True, - num_workers=cfg.LOADER.NUM_WORKERS, - pin_memory=cfg.LOADER.PIN_MEMORY, - ) - test_loader = data.DataLoader( - dataset=test_data, - batch_size=test_batch, - shuffle=False, - num_workers=cfg.LOADER.NUM_WORKERS, - pin_memory=cfg.LOADER.PIN_MEMORY, - ) - return train_loader, test_loader + train_loader = data.DataLoader( + dataset=train_data, + batch_size=train_batch, + shuffle=True, + num_workers=cfg.LOADER.NUM_WORKERS, + pin_memory=cfg.LOADER.PIN_MEMORY, + ) + test_loader = data.DataLoader( + dataset=test_data, + batch_size=test_batch, + shuffle=False, + num_workers=cfg.LOADER.NUM_WORKERS, + pin_memory=cfg.LOADER.PIN_MEMORY, + ) + return train_loader, test_loader + elif name in IMAGEFOLDER_FORMAT: + assert cfg.LOADER.USE_VAL is True, "getting normal dataloader." + return ImageFolder( # using path of training data of ImageNet as `datapath` + root, batch_size=[train_batch, test_batch], + use_val=True, + transforms=transforms, + ).generate_data_loader() def split_dataloader(data_, batch_size, split): -- 2.34.1 From daec5229d64d72e8a08a21a112771838e9a9531f Mon Sep 17 00:00:00 2001 From: LeiZhang Date: Mon, 20 Jun 2022 10:27:05 +0800 Subject: [PATCH 2/7] (Feature): RMINAS support mobilnetV2(proxyless) space & (BUG): fix imagenet RGB normalize bug --- .../RMINAS/rminas_proxyless_imagenet.yaml | 21 ++ scripts/search/RMINAS.py | 60 ++-- xnas/algorithms/RMINAS/sampler/RF_sampling.py | 48 +++- .../fbresnet_imagenet/fbresnet.py | 2 +- xnas/algorithms/RMINAS/utils/random_data.py | 12 +- xnas/core/builder.py | 4 + xnas/datasets/imagenet.py | 21 +- xnas/spaces/OFA/ProxylessNet/cnn.py | 24 +- xnas/spaces/ProxylessNAS/cnn.py | 268 ++++++++++++++++++ 9 files changed, 409 insertions(+), 51 deletions(-) create mode 100644 configs/search/RMINAS/rminas_proxyless_imagenet.yaml create mode 100644 xnas/spaces/ProxylessNAS/cnn.py diff --git a/configs/search/RMINAS/rminas_proxyless_imagenet.yaml b/configs/search/RMINAS/rminas_proxyless_imagenet.yaml new file mode 100644 index 0000000..a0e3007 --- /dev/null +++ b/configs/search/RMINAS/rminas_proxyless_imagenet.yaml @@ -0,0 +1,21 @@ +SPACE: + NAME: 'proxyless' +LOADER: + DATASET: 'imagenet' + NUM_CLASSES: 10 + NUM_WORKERS: 0 + BATCH_SIZE: 128 +OPTIM: + BASE_LR: 0.025 + MOMENTUM: 0.9 + WEIGHT_DECAY: 0.0003 + MAX_EPOCH: 500 +TRAIN: + CHANNELS: 16 + LAYERS: 8 +RMINAS: + LOSS_BETA: 0.8 + RF_WARMUP: 100 + RF_THRESRATE: 0.05 + RF_SUCC: 100 +OUT_DIR: 'exp/rminas' \ No newline at end of file diff --git a/scripts/search/RMINAS.py b/scripts/search/RMINAS.py index 5174662..1d46cd6 100755 --- a/scripts/search/RMINAS.py +++ b/scripts/search/RMINAS.py @@ -4,7 +4,7 @@ import time import numpy as np import torch - +from fvcore.nn import FlopCountAnalysis import xnas.core.config as config import xnas.logger.logging as logging from xnas.core.config import cfg @@ -38,6 +38,8 @@ def rminas_hp_builder(): RF_space = 'nasbenchmacro' from xnas.evaluations.NASBenchMacro.evaluate import evaluate, data api = data + elif cfg.SPACE.NAME == 'proxyless': + RF_space = 'proxyless' # for example : arch = '00000000' # arch = '' # evaluate(arch) @@ -48,7 +50,7 @@ def main(): rminas_hp_builder() - assert cfg.SPACE.NAME in ['infer_nb201', 'infer_darts',"nasbenchmacro"] + assert cfg.SPACE.NAME in ['infer_nb201', 'infer_darts',"nasbenchmacro", "proxyless"] assert cfg.LOADER.DATASET in ['cifar10', 'cifar100', 'imagenet', 'imagenet16_120'], 'dataset error' if cfg.LOADER.DATASET == 'cifar10': @@ -64,7 +66,7 @@ def main(): network.load_state_dict(torch.load('xnas/algorithms/RMINAS/teacher_model/resnet101_cifar100/resnet101.pth')) elif cfg.LOADER.DATASET == 'imagenet': - assert cfg.SPACE.NAME == 'infer_darts' + assert cfg.SPACE.NAME in ('infer_darts', 'proxyless') logger.warning('Our method does not directly search in ImageNet.') logger.warning('Only partial tests have been conducted, please use with caution.') import xnas.algorithms.RMINAS.teacher_model.fbresnet_imagenet.fbresnet as fbresnet @@ -93,8 +95,8 @@ def main(): ce_loss = torch.nn.CrossEntropyLoss(reduction='none').cuda() more_logits = network(more_data_X) _, indices = torch.topk(-ce_loss(more_logits, more_data_y).cpu().detach(), cfg.LOADER.BATCH_SIZE) - data_y = torch.Tensor([more_data_y[i] for i in indices]).long().cuda() - data_X = torch.Tensor([more_data_X[i].cpu().numpy() for i in indices]).cuda() + data_y = more_data_y.detach() + data_X = more_data_X.detach() with torch.no_grad(): feature_res = network.feature_extractor(data_X) @@ -107,6 +109,7 @@ def main(): loss_fun_log = torch.nn.CrossEntropyLoss().cuda() def train_arch(modelinfo): + flops = None if cfg.SPACE.NAME == 'infer_nb201': # get arch arch_config = { @@ -122,6 +125,12 @@ def main(): elif cfg.SPACE.NAME == 'nasbenchmacro': model = space_builder().cuda() optimizer = optimizer_builder("SGD", model.parameters()) + elif cfg.SPACE.NAME == 'proxyless': + model = space_builder(stage_width_list=[16, 24, 40, 80, 96, 192, 320],depth_param=modelinfo[:6],ks=modelinfo[6:27][modelinfo[6:27]>0],expand_ratio=modelinfo[27:][modelinfo[27:]>0],dropout_rate=0).cuda() + optimizer = optimizer_builder("SGD", model.parameters()) + with torch.no_grad(): + tensor = (torch.rand(1, 3, 224, 224).cuda(),) + flops = FlopCountAnalysis(model, tensor).total() # lr_scheduler = lr_scheduler_builder(optimizer) # nbm_trainer = OneShotTrainer( @@ -150,12 +159,14 @@ def main(): optimizer.step() epoch_losses.append(loss.detach().cpu().item()) if cur_epoch == cfg.OPTIM.MAX_EPOCH: - return loss.cpu().detach().numpy(), epoch_losses + + return loss.cpu().detach().numpy(), {'epoch_losses':epoch_losses, 'flops':flops} + trained_arch_darts, trained_loss = [], [] def train_procedure(sample): if cfg.SPACE.NAME == 'infer_nb201': - mixed_loss = train_arch(sample)[0] + mixed_loss, epoch_losses = train_arch(sample)[0] mixed_loss = np.inf if np.isnan(mixed_loss) else mixed_loss trained_loss.append(mixed_loss) arch_arr = sampling.nb201genostr2array(api.arch(sample)) @@ -164,17 +175,25 @@ def main(): elif cfg.SPACE.NAME == 'infer_darts': sample_geno = geno_from_alpha(sampling.darts_sug2alpha(sample)) # type=Genotype trained_arch_darts.append(str(sample_geno)) - mixed_loss = train_arch(sample_geno)[0] + mixed_loss, epoch_losses = train_arch(sample_geno)[0] mixed_loss = np.inf if np.isnan(mixed_loss) else mixed_loss trained_loss.append(mixed_loss) RFS.trained_arch.append({'arch':sample, 'loss':mixed_loss}) elif cfg.SPACE.NAME == 'nasbenchmacro': sample_geno = ''.join(sample.astype('str')) # type=Genotype trained_arch_darts.append((sample_geno)) - mixed_loss, epoch_losses = train_arch(sample) + mixed_loss, info = train_arch(sample) + mixed_loss = np.inf if np.isnan(mixed_loss) else mixed_loss + trained_loss.append(mixed_loss) + RFS.trained_arch.append({'arch':sample, 'loss':mixed_loss,'gt':api[sample_geno]['mean_acc'],'losses':info["epoch_losses"]}) + elif cfg.SPACE.NAME == 'proxyless': + sample_geno = ''.join(sample.astype('str')) # type=Genotype + trained_arch_darts.append((sample_geno)) + mixed_loss, info = train_arch(sample) mixed_loss = np.inf if np.isnan(mixed_loss) else mixed_loss trained_loss.append(mixed_loss) - RFS.trained_arch.append({'arch':sample, 'loss':mixed_loss,'gt':api[sample_geno]['mean_acc'],'losses':epoch_losses}) + RFS.trained_arch.append({'arch':sample, 'loss':mixed_loss,'gt':info["flops"],'losses':info["epoch_losses"]}) + logger.info("sample: {}, loss:{}".format(sample, mixed_loss)) @@ -185,21 +204,6 @@ def main(): for sample in warmup_samples: train_procedure(sample) RFS.Warmup() - logger.info('warmup time cost: {}'.format(str(time.time() - start_time))) - # with open('./rmi_nbm.pkl','wb') as f: - # pickle.dump(RFS.trained_arch,f) - # gt = np.array([_['gt'] for _ in RFS.trained_arch]) - # losses = np.array([_['losses'] for _ in RFS.trained_arch]) - # from scipy.stats import kendalltau - # kdts = [] - # for epoch in range(losses.shape[-1]): - # kdts.append(kendalltau(gt, -losses[:, epoch]).correlation) - # import matplotlib.pyplot as plt - # plt.plot(kdts) - # plt.xlabel('epoch') - # plt.ylabel('kdt') - # plt.savefig('rmi_nbm.png') - # sys.exit() # ====== RF Sampling ====== sampling_time = time.time() sampling_cnt= 0 @@ -231,6 +235,12 @@ def main(): # op_geno = reformat_DARTS(geno_from_alpha(op_alpha)) logger.info('Searched architecture@top50:\n{}'.format(str(op_sample))) print(api[op_sample]['mean_acc']) + elif cfg.SPACE.NAME == 'proxyless': + op_sample = RFS.optimal_arch(method='sum', top=100) + # op_alpha = torch.from_numpy(np.r_[op_sample, op_sample]) + # op_geno = reformat_DARTS(geno_from_alpha(op_alpha)) + logger.info('Searched architecture@top100:\n{}'.format(str(op_sample))) + print(api[op_sample]['mean_acc']) if __name__ == '__main__': main() diff --git a/xnas/algorithms/RMINAS/sampler/RF_sampling.py b/xnas/algorithms/RMINAS/sampler/RF_sampling.py index 500b490..083b6e4 100644 --- a/xnas/algorithms/RMINAS/sampler/RF_sampling.py +++ b/xnas/algorithms/RMINAS/sampler/RF_sampling.py @@ -35,7 +35,8 @@ class RF_suggest(): self.max_space = int(3**8) self.num_estimator = 30 self.spaces = list(api.keys()) - + elif self.space == 'proxyless': + self.num_estimator = 100 self.model = RandomForestClassifier(n_estimators=self.num_estimator,random_state=seed) def _update_lossthres(self): @@ -74,6 +75,8 @@ class RF_suggest(): return [self._single_sample() for _ in range(num_warmup)] elif self.space == 'nasbenchmacro': return [self._single_sample() for _ in range(num_warmup)] + elif self.space == 'proxyless': + return [self._single_sample() for _ in range(num_warmup)] def _single_sample(self, unique=True): if self.space == 'nasbench201': @@ -125,6 +128,28 @@ class RF_suggest(): else: numeric_choice = np.random.randint(3,size=8) return numeric_choice + elif self.space == 'proxyless': + def gen_sample(): + depth = np.array(np.random.randint(1, 4+1, size=5).tolist() + [1]) + anchors = depth+[0,4,8,12,16,20] + ks = np.random.choice([3,5,7], size=21) + expand_ratios = np.random.choice([3,6], size=21) + ed = 4 + for anchor in anchors: + ks[anchor:ed] = 0 + expand_ratios[anchor:ed] = 0 + ed += 4 + sample = np.concatenate([depth, ks, expand_ratios]) + return sample + if unique: + while True: + sample = gen_sample() + if sample.tobytes() not in self.sampled_history: + self.sampled_history.append(sample.tobytes()) + return sample + else: + sample = gen_sample() + return sample def Warmup(self): @@ -177,6 +202,18 @@ class RF_suggest(): for i in _sample_indexes: if self.spaces[i] not in chace_table: _sample_archs.append(np.array(list(self.spaces[i])).astype(int)) + elif self.space == 'proxyless': + _sample_batch = np.array([self._single_sample(unique=False).ravel() for _ in range(self.batch)]) + _tmp_trained_arch = [(i['arch'].tobytes()) for i in self.trained_arch] + _sample_archs = [] + for i in _sample_batch: + if (i).tobytes() not in _tmp_trained_arch: + _sample_archs.append(i) +# print("sample {} archs/batch, cost time: {}".format(len(_sample_archs), time.time()-start_time)) + best_id = np.argmax(self.model.predict_proba(_sample_archs)[:,1]) + best_arch = _sample_archs[best_id] + return best_arch + # _sample_batch = np.array([self._single_sample(unique=True).ravel() for _ in range(self.batch)]) # _tmp_trained_arch = [str(i['arch'].ravel()) for i in self.trained_arch] # _sample_archs = [] @@ -311,3 +348,12 @@ class RF_suggest(): op_arr = np.zeros((_tmp_np.size, 3)) op_arr[np.arange(_tmp_np.size),_tmp_np] = 1 return op_arr.argmax(-1) + elif self.space == 'proxyless': + assert method == 'sum', 'only sum is supported in mb.' + depth = estimate_archs[:, :6] + best_depth = np.eye(4)[depth].argmax(-1)+1 + ks = estimate_archs[:, 6:27]//2 # {3, 5, 7} + best_ks = np.eye(3)[ks].argmax(-1) * 2 + 3 + er = estimate_archs[:, 27:]//3 # {3, 6} + best_er = np.eye(2)[er].agrmax(-1) * 3 + 3 + return np.concatenate([best_depth, best_ks, best_er]) diff --git a/xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet.py b/xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet.py index 96ea1b4..cece2b1 100644 --- a/xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet.py +++ b/xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet.py @@ -7,7 +7,7 @@ import math import torch.utils.model_zoo as model_zoo import torch -WEIGHT_PATH = 'teacher_model/fbresnet_imagenet/fbresnet152.pth' +WEIGHT_PATH = './xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet152.pth' __all__ = ['FBResNet', #'fbresnet18', 'fbresnet34', 'fbresnet50', 'fbresnet101', diff --git a/xnas/algorithms/RMINAS/utils/random_data.py b/xnas/algorithms/RMINAS/utils/random_data.py index 2168754..13f4db9 100644 --- a/xnas/algorithms/RMINAS/utils/random_data.py +++ b/xnas/algorithms/RMINAS/utils/random_data.py @@ -1,5 +1,6 @@ import torch import random +import numpy as np from xnas.datasets.loader import get_normal_dataloader from xnas.datasets.imagenet import ImageFolder @@ -15,11 +16,8 @@ def get_random_data(batchsize, name): else: train_loader, _ = get_normal_dataloader(name, batchsize*16) - target_i = random.randint(0, len(train_loader)-1) - more_data_X, more_data_y = None, None - for i, (more_data_X, more_data_y) in enumerate(train_loader): - if i == target_i: - break - more_data_X = more_data_X.to(device) - more_data_y = more_data_y.to(device) + random_idxs = np.random.randint(0, len(train_loader.dataset), size=train_loader.batch_size) + (more_data_X, more_data_y) = zip(*[train_loader.dataset[idx] for idx in random_idxs]) + more_data_X = torch.stack(more_data_X, dim=0).to(device) + more_data_y = torch.Tensor(more_data_y).long().to(device) return more_data_X, more_data_y diff --git a/xnas/core/builder.py b/xnas/core/builder.py index ecc672b..0c8b1ef 100644 --- a/xnas/core/builder.py +++ b/xnas/core/builder.py @@ -24,6 +24,8 @@ from xnas.datasets.loader import construct_loader from xnas.runner.optimizer import optimizer_builder from xnas.runner.criterion import criterion_builder from xnas.runner.scheduler import lr_scheduler_builder +from xnas.spaces.ProxylessNAS.cnn import MobileNetV2 +from xnas.spaces.ProxylessNAS.super_proxyless import _SuperProxylessNASNets __all__ = [ 'construct_loader', @@ -72,6 +74,8 @@ SUPPORTED_SPACES = { "infer_nb201": _infer_NASBench201, "infer_spos": _infer_SPOS_CNN, "spos_nb201": _SPOS_nb201_CNN, + # "proxyless": _SuperProxylessNASNets, + "proxyless": MobileNetV2, } diff --git a/xnas/datasets/imagenet.py b/xnas/datasets/imagenet.py index 3a6bf73..8a10f03 100644 --- a/xnas/datasets/imagenet.py +++ b/xnas/datasets/imagenet.py @@ -47,9 +47,15 @@ class ImageFolder(): self.num_workers = cfg.LOADER.NUM_WORKERS if num_workers is None else num_workers self.pin_memory = cfg.LOADER.PIN_MEMORY if pin_memory is None else pin_memory self.shuffle = shuffle + # expand batch_size to support different number during training & validating + if isinstance(batch_size, int): + batch_size = [batch_size, batch_size] + elif batch_size is None: + batch_size = [256, 256] + assert len(batch_size) == len(split), "lengths of batch_size and split should be same." self.batch_size = batch_size if not self.use_val: - assert sum(self.split) == 1, "Summation of split should be 1" + assert sum(self._split) == 1, "Summation of split should be 1" self.msrc = None self.loader = torch.utils.data.DataLoader @@ -103,10 +109,9 @@ class ImageFolder(): for class_id in self._class_ids: cont_id = self._class_id_cont_id[class_id] train_im_dir = os.path.join(self._data_path, class_id) - for im_name in os.listdir(train_im_dir): + for im_name in filter(is_image_file, os.listdir(train_im_dir)): im_path = os.path.join(train_im_dir, im_name) - if is_image_file(im_path): - self._imdb.append({"im_path": im_path, "class": cont_id}) + self._imdb.append({"im_path": im_path, "class": cont_id}) logger.info("Number of images: {}".format(len(self._imdb))) logger.info("Number of classes: {}".format(len(self._class_ids))) else: @@ -209,8 +214,10 @@ class ImageList_torch(torch.utils.data.Dataset): random_flip=False): self._imdb = _list self.msrc = msrc - self._bgr_normalized_mean = _rgb_normalized_mean[::-1] - self._bgr_normalized_std = _rgb_normalized_std[::-1] + # self._bgr_normalized_mean = _rgb_normalized_mean[::-1] + # self._bgr_normalized_std = _rgb_normalized_std[::-1] + self._rgb_normalized_mean = _rgb_normalized_mean + self._rgb_normalized_std = _rgb_normalized_std self.crop = crop self.crop_size = crop_size self.min_crop = min_crop @@ -234,7 +241,7 @@ class ImageList_torch(torch.utils.data.Dataset): if self.random_flip: transforms.append(torch_transforms.RandomHorizontalFlip()) transforms.append(torch_transforms.ToTensor()) - transforms.append(torch_transforms.Normalize(mean=self._bgr_normalized_mean, std=self._bgr_normalized_std)) + transforms.append(torch_transforms.Normalize(mean=self._rgb_normalized_mean, std=self._rgb_normalized_std)) self.transform = torch_transforms.Compose(transforms) def __getitem__(self, index): diff --git a/xnas/spaces/OFA/ProxylessNet/cnn.py b/xnas/spaces/OFA/ProxylessNet/cnn.py index b8d54cb..195a92b 100644 --- a/xnas/spaces/OFA/ProxylessNet/cnn.py +++ b/xnas/spaces/OFA/ProxylessNet/cnn.py @@ -140,14 +140,14 @@ class MobileNetV2(ProxylessNASNet): width_mult=1.0, bn_param=(0.1, 1e-3), dropout_rate=0.2, - ks=None, - expand_ratio=None, + ks=None, # a list only include {3, 5, 7} + expand_ratio=None, # in proxyless space only 3 or 6 depth_param=None, stage_width_list=None, ): ks = 3 if ks is None else ks - expand_ratio = 6 if expand_ratio is None else expand_ratio + expand_ratio = [6]*6 if expand_ratio is None else expand_ratio input_channel = 32 last_channel = 1280 @@ -162,12 +162,12 @@ class MobileNetV2(ProxylessNASNet): inverted_residual_setting = [ # t, c, n, s [1, 16, 1, 1], - [expand_ratio, 24, 2, 2], - [expand_ratio, 32, 3, 2], - [expand_ratio, 64, 4, 2], - [expand_ratio, 96, 3, 1], - [expand_ratio, 160, 3, 2], - [expand_ratio, 320, 1, 1], + [None, 24, 2, 2], + [None, 32, 3, 2], + [None, 64, 4, 2], + [None, 96, 3, 1], + [None, 160, 3, 2], + [None, 320, 1, 1], ] if depth_param is not None: @@ -179,6 +179,10 @@ class MobileNetV2(ProxylessNASNet): for i in range(len(inverted_residual_setting)): inverted_residual_setting[i][1] = stage_width_list[i] + if expand_ratio is not None: + for i in range(len(inverted_residual_setting)): + inverted_residual_setting[i][0] = expand_ratio[i] + ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1) _pt = 0 @@ -201,7 +205,7 @@ class MobileNetV2(ProxylessNASNet): stride = s else: stride = 1 - if t == 1: + if t == 1: # only used for first block kernel_size = 3 else: kernel_size = ks[_pt] diff --git a/xnas/spaces/ProxylessNAS/cnn.py b/xnas/spaces/ProxylessNAS/cnn.py new file mode 100644 index 0000000..792c20b --- /dev/null +++ b/xnas/spaces/ProxylessNAS/cnn.py @@ -0,0 +1,268 @@ +import json +import numpy as np +import torch.nn as nn + +from xnas.spaces.OFA.ops import ( + set_layer_from_config, + MBConvLayer, + ConvLayer, + IdentityLayer, + LinearLayer, + ResidualBlock, + GlobalAvgPool2d, +) +from xnas.spaces.OFA.utils import val2list, make_divisible +from xnas.spaces.OFA.MobileNetV3.cnn import WSConv_Network + + +__all__ = ["proxyless_base", "ProxylessNASNet", "MobileNetV2"] + + +def proxyless_base( + net_config=None, + n_classes=None, + bn_param=None, + dropout_rate=None, +): + assert net_config is not None, "Please input a network config" + net_config_json = json.load(open(net_config, "r")) + + if n_classes is not None: + net_config_json["classifier"]["out_features"] = n_classes + if dropout_rate is not None: + net_config_json["classifier"]["dropout_rate"] = dropout_rate + + net = ProxylessNASNet.build_from_config(net_config_json) + if bn_param is not None: + net.set_bn_param(*bn_param) + + return net + + +class ProxylessNASNet(WSConv_Network): + def __init__(self, first_conv, blocks, feature_mix_layer, classifier): + super(ProxylessNASNet, self).__init__() + + self.first_conv = first_conv + self.blocks = nn.ModuleList(blocks) + self.feature_mix_layer = feature_mix_layer + self.global_avg_pool = GlobalAvgPool2d(keep_dim=False) + self.classifier = classifier + + def forward(self, x): + x = self.first_conv(x) + for block in self.blocks: + x = block(x) + if self.feature_mix_layer is not None: + x = self.feature_mix_layer(x) + x = self.global_avg_pool(x) + x = self.classifier(x) + return x + + + @property + def module_str(self): + _str = self.first_conv.module_str + "\n" + for block in self.blocks: + _str += block.module_str + "\n" + _str += self.feature_mix_layer.module_str + "\n" + _str += self.global_avg_pool.__repr__() + "\n" + _str += self.classifier.module_str + return _str + + @property + def config(self): + return { + "name": ProxylessNASNet.__name__, + "bn": self.get_bn_param(), + "first_conv": self.first_conv.config, + "blocks": [block.config for block in self.blocks], + "feature_mix_layer": None + if self.feature_mix_layer is None + else self.feature_mix_layer.config, + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + first_conv = set_layer_from_config(config["first_conv"]) + feature_mix_layer = set_layer_from_config(config["feature_mix_layer"]) + classifier = set_layer_from_config(config["classifier"]) + + blocks = [] + for block_config in config["blocks"]: + blocks.append(ResidualBlock.build_from_config(block_config)) + + net = ProxylessNASNet(first_conv, blocks, feature_mix_layer, classifier) + if "bn" in config: + net.set_bn_param(**config["bn"]) + else: + net.set_bn_param(momentum=0.1, eps=1e-3) + + return net + + def zero_last_gamma(self): + for m in self.modules(): + if isinstance(m, ResidualBlock): + if isinstance(m.conv, MBConvLayer) and isinstance( + m.shortcut, IdentityLayer + ): + m.conv.point_linear.bn.weight.data.zero_() + + @property + def grouped_block_index(self): + info_list = [] + block_index_list = [] + for i, block in enumerate(self.blocks[1:], 1): + if block.shortcut is None and len(block_index_list) > 0: + info_list.append(block_index_list) + block_index_list = [] + block_index_list.append(i) + if len(block_index_list) > 0: + info_list.append(block_index_list) + return info_list + + def load_state_dict(self, state_dict, **kwargs): + current_state_dict = self.state_dict() + + for key in state_dict: + if key not in current_state_dict: + assert ".mobile_inverted_conv." in key + new_key = key.replace(".mobile_inverted_conv.", ".conv.") + else: + new_key = key + current_state_dict[new_key] = state_dict[key] + super(ProxylessNASNet, self).load_state_dict(current_state_dict) + + +class MobileNetV2(ProxylessNASNet): + def __init__( + self, + n_classes=1000, + width_mult=1.0, + bn_param=(0.1, 1e-3), + dropout_rate=0.2, + ks=None, # a list only include {3, 5, 7} + expand_ratio=None, # in proxyless space only 3 or 6 + depth_param=None, + stage_width_list=None, + ): + + ks = 3 if ks is None else ks + expand_ratio = [6]*6 if expand_ratio is None else expand_ratio + + input_channel = 32 + last_channel = 1280 + + input_channel = make_divisible(input_channel * width_mult) + last_channel = ( + make_divisible(last_channel * width_mult) + if width_mult > 1.0 + else last_channel + ) + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [0, 24, 2, 2], + [0, 32, 3, 2], + [0, 64, 4, 2], + [0, 96, 3, 1], + [0, 160, 3, 2], + [0, 320, 1, 1], + ] + + if depth_param is not None: + assert len(depth_param) == 6 + # assert isinstance(depth_param, ) + for i in range(1, len(inverted_residual_setting) - 1): + inverted_residual_setting[i][2] = depth_param[i-1] + + if stage_width_list is not None: + assert len(stage_width_list) == 7 + for i in range(len(inverted_residual_setting)): + inverted_residual_setting[i][1] = stage_width_list[i] + + # if expand_ratio is not None: + # for i in range(len(inverted_residual_setting)): + # inverted_residual_setting[i][0] = expand_ratio[i] + + # ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1) + _pt = 0 + + self.feature_idx = np.cumsum(depth_param)[[1, 3]] + # first conv layer + first_conv = ConvLayer( + 3, + input_channel, + kernel_size=3, + stride=2, + use_bn=True, + act_func="relu6", + ops_order="weight_bn_act", + ) + # inverted residual blocks + blocks = [] + for t, c, n, s in inverted_residual_setting: + output_channel = make_divisible(c * width_mult) + for i in range(n): + if i == 0: + stride = s + else: + stride = 1 + if t == 1: # only used for first block + kernel_size = 3 + er = 1 + else: + kernel_size = ks[_pt].item() + er = expand_ratio[_pt].item() + _pt += 1 + mobile_inverted_conv = MBConvLayer( + in_channels=input_channel, + out_channels=output_channel, + kernel_size=kernel_size, + stride=stride, + expand_ratio=er, + ) + if stride == 1: + if input_channel == output_channel: + shortcut = IdentityLayer(input_channel, input_channel) + else: + shortcut = None + else: + shortcut = None + blocks.append(ResidualBlock(mobile_inverted_conv, shortcut)) + input_channel = output_channel + # 1x1_conv before global average pooling + feature_mix_layer = ConvLayer( + input_channel, + last_channel, + kernel_size=1, + use_bn=True, + act_func="relu6", + ops_order="weight_bn_act", + ) + + classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate) + + super(MobileNetV2, self).__init__( + first_conv, blocks, feature_mix_layer, classifier + ) + + # set bn param + self.set_bn_param(*bn_param) + + def forward_with_features(self, x, *args, **kwargs): + x = self.first_conv(x) + features = [] + for i, block in enumerate(self.blocks): + if i in (self.feature_idx): + features.append(x) + x = block(x) + if self.feature_mix_layer is not None: + x = self.feature_mix_layer(x) + features.append(x) + assert len(features) == 3 + x = self.global_avg_pool(x) + logits = self.classifier(x) + return features, logits \ No newline at end of file -- 2.34.1 From 717240c60f55940af8e3efc5391c784c5052d51b Mon Sep 17 00:00:00 2001 From: xfey Date: Wed, 22 Jun 2022 20:32:09 +0800 Subject: [PATCH 3/7] update ImageNet dataloaders --- xnas/core/config.py | 7 +- xnas/datasets/auto_augment_tf.py | 402 +++++++++++++++++++++++++++ xnas/datasets/imagenet.py | 187 ++++--------- xnas/datasets/loader.py | 20 +- xnas/datasets/transforms_imagenet.py | 124 +++++++++ 5 files changed, 600 insertions(+), 140 deletions(-) create mode 100644 xnas/datasets/auto_augment_tf.py create mode 100644 xnas/datasets/transforms_imagenet.py diff --git a/xnas/core/config.py b/xnas/core/config.py index 7d09b97..8bd6a00 100644 --- a/xnas/core/config.py +++ b/xnas/core/config.py @@ -37,6 +37,9 @@ _C.LOADER.PIN_MEMORY = True # _C.LOADER.BATCH_SIZE = [256, 128] _C.LOADER.BATCH_SIZE = 256 +# augment type using by ImageNet only +# chosen from ['default', 'auto_augment_tf'] +_C.LOADER.TRANSFORM = "default" # ------------------------------------------------------------------------------------ # @@ -150,7 +153,9 @@ _C.TEST = CfgNode(new_allowed=True) _C.TEST.IM_SIZE = 224 -_C.TEST.BATCH_SIZE = 128 +# using specific batchsize for testing +# using search.batch_size if this value keeps -1 +_C.TEST.BATCH_SIZE = -1 diff --git a/xnas/datasets/auto_augment_tf.py b/xnas/datasets/auto_augment_tf.py new file mode 100644 index 0000000..c7f84e9 --- /dev/null +++ b/xnas/datasets/auto_augment_tf.py @@ -0,0 +1,402 @@ +""" Auto Augment +Implementation adapted from timm: https://github.com/rwightman/pytorch-image-models +""" + +import random +import math +from PIL import Image, ImageOps, ImageEnhance +import PIL + + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10. + +_HPARAMS_DEFAULT = dict( + translate_const=250, + img_mean=_FILL, +) + +_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop('resample', Image.NEAREST) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if 'fillcolor' in kwargs and _PIL_VER < (5, 0): + kwargs.pop('fillcolor') + kwargs['resample'] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs['resample']) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + bits_to_keep = max(1, bits_to_keep) # prevent all 0 images + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def _rotate_level_to_arg(level): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30. + level = _randomly_negate(level) + return (level,) + + +def _enhance_level_to_arg(level): + # range [0.1, 1.9] + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) + + +def _shear_level_to_arg(level): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level,) + + +def _translate_abs_level_to_arg(level, translate_const): + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return (level,) + +def _translate_abs_level_to_arg2(level): + level = (level / _MAX_LEVEL) * float(_HPARAMS_DEFAULT['translate_const']) + level = _randomly_negate(level) + return (level,) + +def _translate_rel_level_to_arg(level): + # range [-0.45, 0.45] + level = (level / _MAX_LEVEL) * 0.45 + level = _randomly_negate(level) + return (level,) + + +# def level_to_arg(hparams): +# return { +# 'AutoContrast': lambda level: (), +# 'Equalize': lambda level: (), +# 'Invert': lambda level: (), +# 'Rotate': _rotate_level_to_arg, +# # FIXME these are both different from original impl as I believe there is a bug, +# # not sure what is the correct alternative, hence 2 options that look better +# 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8] +# 'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0] +# 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256] +# 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110] +# 'Color': _enhance_level_to_arg, +# 'Contrast': _enhance_level_to_arg, +# 'Brightness': _enhance_level_to_arg, +# 'Sharpness': _enhance_level_to_arg, +# 'ShearX': _shear_level_to_arg, +# 'ShearY': _shear_level_to_arg, +# 'TranslateX': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']), +# 'TranslateY': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']), +# 'TranslateXRel': lambda level: _translate_rel_level_to_arg(level), +# 'TranslateYRel': lambda level: _translate_rel_level_to_arg(level), +# } + + +NAME_TO_OP = { + 'AutoContrast': auto_contrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'Posterize': posterize, + 'Posterize2': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x_abs, + 'TranslateY': translate_y_abs, + 'TranslateXRel': translate_x_rel, + 'TranslateYRel': translate_y_rel, +} + + +def pass_fn(input): + return () + + +def _conversion0(input): + return (int((input / _MAX_LEVEL) * 4) + 4,) + + +def _conversion1(input): + return (4 - int((input / _MAX_LEVEL) * 4),) + + +def _conversion2(input): + return (int((input / _MAX_LEVEL) * 256),) + + +def _conversion3(input): + return (int((input / _MAX_LEVEL) * 110),) + + +class AutoAugmentOp: + def __init__(self, name, prob, magnitude, hparams={}): + self.aug_fn = NAME_TO_OP[name] + # self.level_fn = level_to_arg(hparams)[name] + if name == 'AutoContrast' or name == 'Equalize' or name == 'Invert': + self.level_fn = pass_fn + elif name == 'Rotate': + self.level_fn = _rotate_level_to_arg + elif name == 'Posterize': + self.level_fn = _conversion0 + elif name == 'Posterize2': + self.level_fn = _conversion1 + elif name == 'Solarize': + self.level_fn = _conversion2 + elif name == 'SolarizeAdd': + self.level_fn = _conversion3 + elif name == 'Color' or name == 'Contrast' or name == 'Brightness' or name == 'Sharpness': + self.level_fn = _enhance_level_to_arg + elif name == 'ShearX' or name == 'ShearY': + self.level_fn = _shear_level_to_arg + elif name == 'TranslateX' or name == 'TranslateY': + self.level_fn = _translate_abs_level_to_arg2 + elif name == 'TranslateXRel' or name == 'TranslateYRel': + self.level_fn = _translate_rel_level_to_arg + else: + print("{} not recognized".format({})) + self.prob = prob + self.magnitude = magnitude + # If std deviation of magnitude is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from normal dist + # with mean magnitude and std-dev of magnitude_std. + # NOTE This is being tested as it's not in paper or reference impl. + self.magnitude_std = 0.5 # FIXME add arg/hparam + self.kwargs = { + 'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL, + 'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION + } + + def __call__(self, img): + if self.prob < random.random(): + return img + magnitude = self.magnitude + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) + level_args = self.level_fn(magnitude) + return self.aug_fn(img, *level_args, **self.kwargs) + + +def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT): + # ImageNet policy from TPU EfficientNet impl, cannot find + # a paper reference. + policy = [ + [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], + [('Color', 0.4, 9), ('Equalize', 0.6, 3)], + [('Color', 0.4, 1), ('Rotate', 0.6, 8)], + [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], + [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], + [('Color', 0.2, 0), ('Equalize', 0.8, 8)], + [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], + [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], + [('Color', 0.6, 1), ('Equalize', 1.0, 2)], + [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], + [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], + [('Color', 0.4, 7), ('Equalize', 0.6, 0)], + [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Solarize', 0.6, 8), ('Color', 0.6, 9)], + [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], + [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], + [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], + [('ShearY', 0.8, 0), ('Color', 0.6, 4)], + [('Color', 1.0, 0), ('Rotate', 0.6, 2)], + [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], + [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], + [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], + [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], + [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], + [('Color', 0.8, 6), ('Rotate', 0.4, 5)], + ] + pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy] + return pc + + +def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT): + # ImageNet policy from https://arxiv.org/abs/1805.09501 + policy = [ + [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], + [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], + [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)], + [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], + [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)], + [('Rotate', 0.8, 8), ('Color', 0.4, 0)], + [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], + [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Rotate', 0.8, 8), ('Color', 1.0, 2)], + [('Color', 0.8, 8), ('Solarize', 0.8, 7)], + [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)], + [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)], + [('Color', 0.4, 0), ('Equalize', 0.6, 3)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + ] + pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy] + return pc + + +def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT): + if name == 'original': + return auto_augment_policy_original(hparams) + elif name == 'v0': + return auto_augment_policy_v0(hparams) + else: + print("Unknown auto_augmentation policy {}".format(name)) + raise AssertionError() + + +class AutoAugment: + + def __init__(self, policy): + self.policy = policy + + def __call__(self, img): + sub_policy = random.choice(self.policy) + for op in sub_policy: + img = op(img) + return img diff --git a/xnas/datasets/imagenet.py b/xnas/datasets/imagenet.py index f565f91..6968b04 100644 --- a/xnas/datasets/imagenet.py +++ b/xnas/datasets/imagenet.py @@ -1,100 +1,67 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - """ImageNet dataset.""" -import math import os import re import numpy as np import torch import torch.utils.data -import torchvision.transforms as torch_transforms from PIL import Image from torch.utils.data.distributed import DistributedSampler import xnas.logger.logging as logging from xnas.core.config import cfg -from xnas.datasets.transforms import MultiSizeRandomCrop +from xnas.datasets.transforms_imagenet import get_data_transform logger = logging.get_logger(__name__) class ImageFolder(): - def __init__( - self, - datapath, - batch_size, - split=None, - use_val=False, - dataset_name='imagenet', - _rgb_normalized_mean=None, - _rgb_normalized_std=None, - transforms=None, - num_workers=None, - pin_memory=None, - shuffle=True - ): + """New ImageFolder + Support ImageNet only currently. + """ + def __init__(self, datapath, batch_size, split=None, use_val=False, augment_type='default', **kwargs): datapath = './data/imagenet/' if not datapath else datapath assert os.path.exists(datapath), "Data path '{}' not found".format(datapath) self.use_val = use_val - self.data_path, self.split, self.dataset_name = datapath, split, dataset_name - self.rgb_normalized_mean, self.rgb_normalized_std = _rgb_normalized_mean, _rgb_normalized_std - self.num_workers = cfg.LOADER.NUM_WORKERS if num_workers is None else num_workers - self.pin_memory = cfg.LOADER.PIN_MEMORY if pin_memory is None else pin_memory - self.shuffle = shuffle + self.data_path = datapath + self.split = split self.batch_size = batch_size - if not self.use_val: - assert sum(self.split) == 1, "Summation of split should be 1" - - self.msrc = None - self.loader = torch.utils.data.DataLoader - # self.collate_fn = None + self.num_workers = cfg.LOADER.NUM_WORKERS + self.pin_memory = cfg.LOADER.PIN_MEMORY + self.augment_type = augment_type + self.kwargs = kwargs - if transforms is None: - im_size = cfg.SEARCH.IM_SIZE if len(cfg.SEARCH.MULTI_SIZES)==0 else cfg.SEARCH.MULTI_SIZES - self.transforms = [{'crop': 'random', 'crop_size': im_size, 'min_crop': 0.08, 'random_flip': True}, - {'crop': 'center', 'crop_size': cfg.TEST.IM_SIZE, 'min_crop': -1, 'random_flip': False}] # NOTE: min_crop is not used here. - else: - self.transforms = transforms if not self.use_val: - assert len(self.transforms) == len(self.split), "Length of split and transforms should be equal" - else: - assert len(self.transforms) == 2 + assert sum(self.split) == 1, "Summation of split should be 1." - # Check if using multisize_random_crop - if len(cfg.SEARCH.MULTI_SIZES): + # setting default loader if not using MultiSizeRandomCrop + if len(cfg.SEARCH.MULTI_SIZES) == 0: + self.loader = torch.utils.data.DataLoader + else: from xnas.datasets.utils.msrc_loader import msrc_DataLoader - self.msrc = MultiSizeRandomCrop(cfg.SEARCH.MULTI_SIZES) self.loader = msrc_DataLoader - logger.info("Using Random MultiSize Crop, continuous={} candidate im_sizes={}".format(self.msrc.CONTINUOUS, self.msrc.CANDIDATE_SIZES)) + logger.info("Using MultiSize RandomCrop, continuous={} candidate im_sizes={}".format(self.msrc.CONTINUOUS, self.msrc.CANDIDATE_SIZES)) - # Read all dataset + # Acquiring transforms + logger.info("Constructing transforms") + self.train_transform, self.test_transform = self._build_transfroms() + + # Read all datasets logger.info("Constructing ImageFolder") self._construct_imdb() - + def _construct_imdb(self): # Images are stored per class in subdirs (format: n) if not self.use_val: split_files = os.listdir(self.data_path) else: split_files = os.listdir(os.path.join(self.data_path, "train")) - if self.dataset_name == "imagenet": - # imagenet format folder names - self._class_ids = sorted( - f for f in split_files if re.match(r"^n[0-9]+$", f)) - self.rgb_normalized_mean = [0.485, 0.456, 0.406] - self.rgb_normalized_std = [0.229, 0.224, 0.225] - elif self.dataset_name == 'custom': - self._class_ids = sorted( - f for f in split_files if not f[0] == '.') - else: - raise NotImplementedError + # imagenet format folder names + self._class_ids = sorted( + f for f in split_files if re.match(r"^n[0-9]+$", f)) # Map class ids to contiguous ids self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)} @@ -130,7 +97,11 @@ class ImageFolder(): logger.info("Number of classes: {}".format(len(self._class_ids))) logger.info("Number of TRAIN images: {}".format(len(self._train_imdb))) logger.info("Number of VAL images: {}".format(len(self._val_imdb))) - + + def _build_transfroms(self): + # KWARGS for 'auto_augment_tf': policy='v0', interpolation='bilinear' + return get_data_transform(augment=self.augment_type, **self.kwargs) + def generate_data_loader(self): if not self.use_val: indices = list(range(len(self._imdb))) @@ -144,16 +115,17 @@ class ImageFolder(): _current_index = int(len(self._imdb) * _current_partition) _current_indices = indices[pre_index: _current_index] assert not len(_current_indices) == 0, "The length of indices is zero!" - dataset = ImageList_torch([self._imdb[j] for j in _current_indices], - self.msrc, # add support for multisize_random_crop - _rgb_normalized_mean=self.rgb_normalized_mean, - _rgb_normalized_std=self.rgb_normalized_std, - **self.transforms[i]) + dataset = ImageList_torch( + [self._imdb[j] for j in _current_indices], + # using the first split only as training dataset + transform=self.train_transform if i==0 else self.test_transform + ) sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None loader = self.loader(dataset, batch_size=self.batch_size[i], shuffle=(False if sampler else True), sampler=sampler, + drop_last=(True if i==0 else False), num_workers=self.num_workers, pin_memory=self.pin_memory) data_loaders.append(loader) @@ -161,82 +133,37 @@ class ImageFolder(): pre_index = _current_index return data_loaders else: - train_dataset = ImageList_torch( - self._train_imdb, - self.msrc, - _rgb_normalized_mean=self.rgb_normalized_mean, - _rgb_normalized_std=self.rgb_normalized_std, - **self.transforms[0] - ) - sampler = DistributedSampler(train_dataset) if cfg.NUM_GPUS > 1 else None + train_dataset = ImageList_torch(self._train_imdb, self.train_transform) + train_sampler = DistributedSampler(train_dataset) if cfg.NUM_GPUS > 1 else None train_loader = self.loader(train_dataset, - batch_size=self.batch_size[0], - shuffle=(False if sampler else True), - sampler=sampler, - num_workers=self.num_workers, - pin_memory=self.pin_memory) - val_dataset = ImageList_torch( - self._val_imdb, - self.msrc, - _rgb_normalized_mean=self.rgb_normalized_mean, - _rgb_normalized_std=self.rgb_normalized_std, - **self.transforms[1] - ) - sampler = DistributedSampler(val_dataset) if cfg.NUM_GPUS > 1 else None + batch_size=self.batch_size[0], + shuffle=(False if train_sampler else True), + sampler=train_sampler, + drop_last=True, + num_workers=self.num_workers, + pin_memory=self.pin_memory) + + val_dataset = ImageList_torch(self._val_imdb, self.test_transform) + val_sampler = DistributedSampler(val_dataset) if cfg.NUM_GPUS > 1 else None valid_loader = self.loader(val_dataset, - batch_size=self.batch_size[1], - shuffle=(False if sampler else True), - sampler=sampler, - num_workers=self.num_workers, - pin_memory=self.pin_memory) + batch_size=self.batch_size[1], + shuffle=(False if val_sampler else True), + sampler=val_sampler, + drop_last=False, + num_workers=self.num_workers, + pin_memory=self.pin_memory) return [train_loader, valid_loader] - class ImageList_torch(torch.utils.data.Dataset): ''' ImageList dataloader with torch backends From https://github.com/pytorch/vision/issues/81 ''' - def __init__( - self, - _list, - msrc=None, - _rgb_normalized_mean=None, - _rgb_normalized_std=None, - crop='random', - crop_size=224, - min_crop=0.08, - random_flip=False): - self._imdb = _list - self.msrc = msrc - self._rgb_normalized_mean = _rgb_normalized_mean - self._rgb_normalized_std = _rgb_normalized_std - self.crop = crop - self.crop_size = crop_size - self.min_crop = min_crop - self.random_flip = random_flip + def __init__(self, list, transform): + self._imdb = list + self.transform = transform self.loader = pil_loader - self._construct_transforms() - - def _construct_transforms(self): - transforms = [] - if self.crop == "random": - if isinstance(self.crop_size, int): - transforms.append(torch_transforms.RandomResizedCrop(self.crop_size, scale=(self.min_crop, 1.0))) - elif isinstance(self.crop_size, list): - # using MultiSizeRandomCrop - transforms.append(self.msrc) - elif self.crop == "center": - transforms.append(torch_transforms.Resize(math.ceil(self.crop_size / 0.875))) # assert crop_size==224 - transforms.append(torch_transforms.CenterCrop(self.crop_size)) - # TODO: color augmentation support - # transforms.append(torch_transforms.ColorJitter(brightness=0.4, contrast=0.4,saturation=0.4, hue=0.2)) - if self.random_flip: - transforms.append(torch_transforms.RandomHorizontalFlip()) - transforms.append(torch_transforms.ToTensor()) - transforms.append(torch_transforms.Normalize(mean=self._rgb_normalized_mean, std=self._rgb_normalized_std)) - self.transform = torch_transforms.Compose(transforms) def __getitem__(self, index): impath = self._imdb[index]["im_path"] diff --git a/xnas/datasets/loader.py b/xnas/datasets/loader.py index d0266f3..4e2b467 100644 --- a/xnas/datasets/loader.py +++ b/xnas/datasets/loader.py @@ -26,6 +26,7 @@ def construct_loader( cutout_length=0, use_classes=None, transforms=None, + **kwargs ): """Construct NAS dataloaders with train&valid subsets.""" @@ -52,10 +53,11 @@ def construct_loader( train_data, _ = get_data(name, datapath, cutout_length, use_classes=use_classes, transforms=transforms) return split_dataloader(train_data, batch_size, split) elif name in IMAGEFOLDER_FORMAT: - assert cfg.LOADER.USE_VAL is False, "do not get normal dataloaders." + assert cfg.LOADER.USE_VAL is False, "do not using VAL dataset." + aug_type = cfg.LOADER.TRANSFORM return ImageFolder( # using path of training data of ImageNet as `datapath` datapath, batch_size=batch_size, split=split, - use_val=False, transforms=transforms, + use_val=False, augment_type=aug_type, **kwargs ).generate_data_loader() else: print("dataset not supported.") @@ -140,12 +142,13 @@ def get_normal_dataloader( cutout_length=0, use_classes=None, transforms=None, + **kwargs ): name=cfg.LOADER.DATASET if name is None else name train_batch=cfg.LOADER.BATCH_SIZE if train_batch is None else train_batch name=cfg.LOADER.DATASET - root=cfg.LOADER.DATAPATH - test_batch=cfg.TEST.BATCH_SIZE + datapath=cfg.LOADER.DATAPATH + test_batch=cfg.LOADER.BATCH_SIZE if cfg.TEST.BATCH_SIZE == -1 else cfg.TEST.BATCH_SIZE assert (name in SUPPORTED_DATASETS) or (name in IMAGEFOLDER_FORMAT), "dataset not supported." assert isinstance(train_batch, int), "normal dataloader using single training batch-size, not list." @@ -155,7 +158,7 @@ def get_normal_dataloader( if name in SUPPORTED_DATASETS: # get normal dataloaders with train&test subsets. - train_data, test_data = get_data(name, root, cutout_length, use_classes=use_classes, transforms=transforms) + train_data, test_data = get_data(name, datapath, cutout_length, use_classes=use_classes, transforms=transforms) train_loader = data.DataLoader( dataset=train_data, @@ -174,13 +177,12 @@ def get_normal_dataloader( return train_loader, test_loader elif name in IMAGEFOLDER_FORMAT: assert cfg.LOADER.USE_VAL is True, "getting normal dataloader." + aug_type = cfg.LOADER.TRANSFORM return ImageFolder( # using path of training data of ImageNet as `datapath` - root, batch_size=[train_batch, test_batch], - use_val=True, - transforms=transforms, + datapath, batch_size=[train_batch, test_batch], + use_val=True, augment_type=aug_type, **kwargs ).generate_data_loader() - def split_dataloader(data_, batch_size, split): assert 0 not in split, "illegal split list with zero." assert sum(split) == 1, "summation of split should be one." diff --git a/xnas/datasets/transforms_imagenet.py b/xnas/datasets/transforms_imagenet.py new file mode 100644 index 0000000..8b6ee7b --- /dev/null +++ b/xnas/datasets/transforms_imagenet.py @@ -0,0 +1,124 @@ +import math +import torch +from PIL import Image +import torchvision.transforms as transforms + +from xnas.core.config import cfg +from xnas.datasets.auto_augment_tf import auto_augment_policy, AutoAugment + + +IMAGENET_RGB_MEAN = [0.485, 0.456, 0.406] +IMAGENET_RGB_STD = [0.229, 0.224, 0.225] + + +def get_data_transform(augment, **kwargs): + if len(cfg.SEARCH.MULTI_SIZES)==0: + # using single image_size for training + train_crop_size = cfg.SEARCH.IM_SIZE + else: + # using MultiSize_RandomCrop + train_crop_size = cfg.SEARCH.MULTI_SIZES + min_train_scale = 0.08 + test_scale = math.ceil(cfg.TEST.IM_SIZE / 0.875) # 224 / 0.875 = 256 + test_crop_size = cfg.TEST.IM_SIZE # do not crop and using 224 by default. + + interpolation = transforms.InterpolationMode.BICUBIC + if 'interpolation' in kwargs.keys() and kwargs['interpolation'] == 'bilinear': + interpolation = transforms.InterpolationMode.BILINEAR + + da_args = { + 'train_crop_size': train_crop_size, + 'train_min_scale': min_train_scale, + 'test_scale': test_scale, + 'test_crop_size': test_crop_size, + 'interpolation': interpolation, + } + + if augment == 'default': + return build_default_transform(**da_args) + elif augment == 'auto_augment_tf': + policy = 'v0' if 'policy' not in kwargs.keys() else kwargs['policy'] + return build_imagenet_auto_augment_tf_transform(policy=policy, **da_args) + else: + raise ValueError(augment) + + +def get_normalize(): + return transforms.Normalize( + mean=torch.Tensor(IMAGENET_RGB_MEAN), + std=torch.Tensor(IMAGENET_RGB_STD), + ) + + +def get_randomResizedCrop(train_crop_size=224, train_min_scale=0.08, interpolation=transforms.InterpolationMode.BICUBIC): + if isinstance(train_crop_size, int): + return transforms.RandomResizedCrop(train_crop_size, scale=(train_min_scale, 1.0), interpolation=interpolation) + elif isinstance(train_crop_size, list): + from xnas.datasets.transforms import MultiSizeRandomCrop + msrc = MultiSizeRandomCrop(train_crop_size) + return msrc + else: + raise TypeError(train_crop_size) + + +def build_default_transform( + train_crop_size=224, train_min_scale=0.08, test_scale=256, test_crop_size=224, interpolation=transforms.InterpolationMode.BICUBIC +): + normalize = get_normalize() + train_crop_transform = get_randomResizedCrop( + train_crop_size, train_min_scale, interpolation + ) + train_transform = transforms.Compose( + [ + # transforms.RandomResizedCrop(train_crop_size, interpolation=interpolation), + train_crop_transform, + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ] + ) + test_transform = transforms.Compose( + [ + transforms.Resize(test_scale, interpolation=interpolation), + transforms.CenterCrop(test_crop_size), + transforms.ToTensor(), + normalize, + ] + ) + return train_transform, test_transform + + +def build_imagenet_auto_augment_tf_transform( + policy='v0', train_crop_size=224, train_min_scale=0.08, test_scale=256, test_crop_size=224, interpolation=transforms.InterpolationMode.BICUBIC +): + + normalize = get_normalize() + img_size = train_crop_size + aa_params = { + "translate_const": int(img_size * 0.45), + "img_mean": tuple(round(x) for x in IMAGENET_RGB_MEAN), + } + + aa_policy = AutoAugment(auto_augment_policy(policy, aa_params)) + train_crop_transform = get_randomResizedCrop( + train_crop_size, train_min_scale, interpolation + ) + train_transform = transforms.Compose( + [ + # transforms.RandomResizedCrop(train_crop_size, interpolation=interpolation), + train_crop_transform, + transforms.RandomHorizontalFlip(), + aa_policy, + transforms.ToTensor(), + normalize, + ] + ) + test_transform = transforms.Compose( + [ + transforms.Resize(test_scale, interpolation=interpolation), + transforms.CenterCrop(test_crop_size), + transforms.ToTensor(), + normalize, + ] + ) + return train_transform, test_transform -- 2.34.1 From 0fda89c83fe001171a94f39996ed0b6fa50a9400 Mon Sep 17 00:00:00 2001 From: xfey Date: Wed, 22 Jun 2022 20:34:42 +0800 Subject: [PATCH 4/7] modify evaluation method Median -> Avg --- scripts/train/DARTS.py | 6 +++--- xnas/runner/scheduler.py | 6 ++++-- xnas/runner/trainer.py | 14 +++++++------- xnas/runner/trainer_spos.py | 17 ++++++++--------- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/scripts/train/DARTS.py b/scripts/train/DARTS.py index 45f8a57..268abec 100644 --- a/scripts/train/DARTS.py +++ b/scripts/train/DARTS.py @@ -97,9 +97,9 @@ class Darts_Retrainer(Trainer): self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.test_meter.log_iter_stats(cur_epoch, cur_iter) self.test_meter.iter_tic() - top1_err = self.test_meter.mb_top1_err.get_win_median() - self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch) - self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch) + top1_err = self.test_meter.mb_top1_err.get_win_avg() + self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch) + self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch) # Log epoch stats self.test_meter.log_epoch_stats(cur_epoch) self.test_meter.reset() diff --git a/xnas/runner/scheduler.py b/xnas/runner/scheduler.py index 02b9c2d..b77a9b5 100644 --- a/xnas/runner/scheduler.py +++ b/xnas/runner/scheduler.py @@ -85,14 +85,15 @@ class GradualWarmupScheduler(_LRScheduler): def _calc_learning_rate( init_lr, n_epochs, epoch, n_iter=None, iter=0, ): - if cfg.SEARCH.LOSS_FUN.startswith("cross_entropy"): + if cfg.OPTIM.LR_POLICY == "cos": t_total = n_epochs * n_iter t_cur = epoch * n_iter + iter lr = 0.5 * init_lr * (1 + math.cos(math.pi * t_cur / t_total)) else: - raise ValueError("do not support: {}".format(cfg.SEARCH.LOSS_FUN)) + raise ValueError("do not support: {}".format(cfg.OPTIM.LR_POLICY)) return lr + def _warmup_adjust_learning_rate( init_lr, n_epochs, epoch, n_iter, iter=0, warmup_lr=0 ): @@ -102,6 +103,7 @@ def _warmup_adjust_learning_rate( new_lr = T_cur / t_total * (init_lr - warmup_lr) + warmup_lr return new_lr + def adjust_learning_rate_per_batch(epoch, n_iter=None, iter=0, warmup=False): """adjust learning of a given optimizer and return the new learning rate""" diff --git a/xnas/runner/trainer.py b/xnas/runner/trainer.py index a78089e..9146658 100644 --- a/xnas/runner/trainer.py +++ b/xnas/runner/trainer.py @@ -114,9 +114,9 @@ class Trainer(Recorder): self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.test_meter.log_iter_stats(cur_epoch, cur_iter) self.test_meter.iter_tic() - top1_err = self.test_meter.mb_top1_err.get_win_median() - self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch) - self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch) + top1_err = self.test_meter.mb_top1_err.get_win_avg() + self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch) + self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch) # Log epoch stats self.test_meter.log_epoch_stats(cur_epoch) self.test_meter.reset() @@ -349,9 +349,9 @@ class OneShotTrainer(Trainer): self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.test_meter.log_iter_stats(cur_epoch, cur_iter) self.test_meter.iter_tic() - top1_err = self.test_meter.mb_top1_err.get_win_median() - self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch) - self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch) + top1_err = self.test_meter.mb_top1_err.get_win_avg() + self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch) + self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch) # Log epoch stats self.test_meter.log_epoch_stats(cur_epoch) self.test_meter.reset() @@ -372,7 +372,7 @@ class OneShotTrainer(Trainer): top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) top1_err, top5_err = top1_err.item(), top5_err.item() self.evaluate_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) - top1_err = self.evaluate_meter.mb_top1_err.get_win_median() + top1_err = self.evaluate_meter.mb_top1_err.get_win_avg() # self.evaluate_sampler.record(choice, top1_err) self.evaluate_meter.reset() return top1_err diff --git a/xnas/runner/trainer_spos.py b/xnas/runner/trainer_spos.py index 5be33f0..50e9835 100644 --- a/xnas/runner/trainer_spos.py +++ b/xnas/runner/trainer_spos.py @@ -114,9 +114,9 @@ class Trainer(Recorder): self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.test_meter.log_iter_stats(cur_epoch, cur_iter) self.test_meter.iter_tic() - top1_err = self.test_meter.mb_top1_err.get_win_median() - self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch) - self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch) + top1_err = self.test_meter.mb_top1_err.get_win_avg() + self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch) + self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch) # Log epoch stats self.test_meter.log_epoch_stats(cur_epoch) self.test_meter.reset() @@ -381,10 +381,9 @@ class OneShotTrainer(Trainer): self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.test_meter.log_iter_stats(cur_epoch, cur_iter) self.test_meter.iter_tic() - top1_err = self.test_meter.mb_top1_err.get_win_median() - top1_err_avg = self.test_meter.mb_top1_err.get_global_avg() - self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch) - self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch) + top1_err = self.test_meter.mb_top1_err.get_win_avg() + self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch) + self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch) # Log epoch stats self.test_meter.log_epoch_stats(cur_epoch) self.test_meter.reset() @@ -392,7 +391,7 @@ class OneShotTrainer(Trainer): if self.best_err > top1_err: self.best_err = top1_err self.saving(cur_epoch, best=True) - return top1_err_avg + return top1_err @torch.no_grad() def evaluate_epoch(self, sample): @@ -405,7 +404,7 @@ class OneShotTrainer(Trainer): top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) top1_err, top5_err = top1_err.item(), top5_err.item() self.evaluate_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) - top1_err = self.evaluate_meter.mb_top1_err.get_win_median() + top1_err = self.evaluate_meter.mb_top1_err.get_win_avg() # self.evaluate_sampler.record(choice, top1_err) self.evaluate_meter.reset() return top1_err -- 2.34.1 From e224236db4ca484aa9b9fe52a659b1e6b8882641 Mon Sep 17 00:00:00 2001 From: xfey Date: Wed, 22 Jun 2022 21:00:26 +0800 Subject: [PATCH 5/7] add BigNAS and fix bugs --- configs/search/BigNAS/eval.yaml | 89 ++++ configs/search/BigNAS/search.yaml | 133 ++++++ configs/search/BigNAS/train.yaml | 95 ++++ examples/search/OFA/train_supernet.sh | 2 +- scripts/search/OFA/eval_supernet.py | 10 +- scripts/search/OFA/train_supernet.py | 16 +- tests/ofa_matrices_test.py | 2 +- xnas/runner/criterion.py | 40 +- xnas/runner/optimizer.py | 1 - xnas/spaces/BigNAS/cnn.py | 653 ++++++++++++++++++++++++++ xnas/spaces/BigNAS/dynamic_layers.py | 331 +++++++++++++ xnas/spaces/BigNAS/dynamic_ops.py | 181 +++++++ xnas/spaces/BigNAS/ops.py | 133 ++++++ xnas/spaces/BigNAS/utils.py | 134 ++++++ xnas/spaces/OFA/dynamic_ops.py | 34 +- xnas/spaces/OFA/ops.py | 106 ++++- xnas/spaces/OFA/utils.py | 25 + 17 files changed, 1921 insertions(+), 64 deletions(-) create mode 100644 configs/search/BigNAS/eval.yaml create mode 100644 configs/search/BigNAS/search.yaml create mode 100644 configs/search/BigNAS/train.yaml create mode 100644 xnas/spaces/BigNAS/cnn.py create mode 100644 xnas/spaces/BigNAS/dynamic_layers.py create mode 100644 xnas/spaces/BigNAS/dynamic_ops.py create mode 100644 xnas/spaces/BigNAS/ops.py create mode 100644 xnas/spaces/BigNAS/utils.py diff --git a/configs/search/BigNAS/eval.yaml b/configs/search/BigNAS/eval.yaml new file mode 100644 index 0000000..65ca309 --- /dev/null +++ b/configs/search/BigNAS/eval.yaml @@ -0,0 +1,89 @@ +NUM_GPUS: 4 +RNG_SEED: 2 +SPACE: + NAME: 'bignas' +LOADER: + DATASET: 'imagenet' + NUM_CLASSES: 1000 + BATCH_SIZE: 128 + NUM_WORKERS: 4 + USE_VAL: True + TRANSFORM: "auto_augment_tf" +SEARCH: + IM_SIZE: 224 +BIGNAS: + BN_MOMENTUM: 0.1 + BN_EPS: 1.e-5 + POST_BN_CALIBRATION_BATCH_NUM: 64 + ACTIVE_SUBNET: # subnet for evaluation + RESOLUTION: 192 + WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792] + KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3] + EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6] + DEPTH: [1, 3, 3, 3, 3, 3, 1] + SUPERNET_CFG: + use_v3_head: True + resolutions: [192, 224, 256, 288] + first_conv: + c: [16, 24] + act_func: 'swish' + s: 2 + mb1: + c: [16, 24] + d: [1, 2] + k: [3, 5] + t: [1] + s: 1 + act_func: 'swish' + se: False + mb2: + c: [24, 32] + d: [3, 4, 5] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: False + mb3: + c: [32, 40] + d: [3, 4, 5, 6] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: True + mb4: + c: [64, 72] + d: [3, 4, 5, 6] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: False + mb5: + c: [112, 120, 128] + d: [3, 4, 5, 6, 7, 8] + k: [3, 5] + t: [4, 5, 6] + s: 1 + act_func: 'swish' + se: True + mb6: + c: [192, 200, 208, 216] + d: [3, 4, 5, 6, 7, 8] + k: [3, 5] + t: [6] + s: 2 + act_func: 'swish' + se: True + mb7: + c: [216, 224] + d: [1, 2] + k: [3, 5] + t: [6] + s: 1 + act_func: 'swish' + se: True + last_conv: + c: [1792, 1984] + act_func: 'swish' diff --git a/configs/search/BigNAS/search.yaml b/configs/search/BigNAS/search.yaml new file mode 100644 index 0000000..bc8a928 --- /dev/null +++ b/configs/search/BigNAS/search.yaml @@ -0,0 +1,133 @@ +NUM_GPUS: 1 +RNG_SEED: 2 +SPACE: + NAME: 'bignas' +LOADER: + DATASET: 'imagenet' + NUM_CLASSES: 1000 + BATCH_SIZE: 128 + NUM_WORKERS: 8 + USE_VAL: True + TRANSFORM: "auto_augment_tf" +SEARCH: + IM_SIZE: 224 + WEIGHTS: "exp/search/test/checkpoints/best_model_epoch_0009.pyth" +BIGNAS: + CONSTRAINT_FLOPS: 6.e+8 # 600M + NUM_MUTATE: 200 + BN_MOMENTUM: 0.1 + BN_EPS: 1.e-5 + POST_BN_CALIBRATION_BATCH_NUM: 64 + # ACTIVE_SUBNET: # subnet for evaluation + # RESOLUTION: 192 + # WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792] + # KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3] + # EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6] + # DEPTH: [1, 3, 3, 3, 3, 3, 1] + SEARCH_CFG_SETS: + resolutions: [224, 256] + first_conv: + c: [16] + mb1: + c: [16] + d: [2] + k: [3] + t: [1] + mb2: + c: [24] + d: [3] + k: [3] + t: [5] + mb3: + c: [32] + d: [4] + k: [3] + t: [5] + mb4: + c: [64] + d: [5] + k: [3] + t: [5] + mb5: + c: [120] + d: [6] + k: [3] + t: [5] + mb6: + c: [192] + d: [6] + k: [3, 5] + t: [6] + mb7: + c: [216] + d: [2] + k: [3] + t: [6] + last_conv: + c: [1792] + SUPERNET_CFG: + use_v3_head: True + resolutions: [192, 224, 256, 288] + first_conv: + c: [16, 24] + act_func: 'swish' + s: 2 + mb1: + c: [16, 24] + d: [1, 2] + k: [3, 5] + t: [1] + s: 1 + act_func: 'swish' + se: False + mb2: + c: [24, 32] + d: [3, 4, 5] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: False + mb3: + c: [32, 40] + d: [3, 4, 5, 6] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: True + mb4: + c: [64, 72] + d: [3, 4, 5, 6] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: False + mb5: + c: [112, 120, 128] + d: [3, 4, 5, 6, 7, 8] + k: [3, 5] + t: [4, 5, 6] + s: 1 + act_func: 'swish' + se: True + mb6: + c: [192, 200, 208, 216] + d: [3, 4, 5, 6, 7, 8] + k: [3, 5] + t: [6] + s: 2 + act_func: 'swish' + se: True + mb7: + c: [216, 224] + d: [1, 2] + k: [3, 5] + t: [6] + s: 1 + act_func: 'swish' + se: True + last_conv: + c: [1792, 1984] + act_func: 'swish' diff --git a/configs/search/BigNAS/train.yaml b/configs/search/BigNAS/train.yaml new file mode 100644 index 0000000..68abc9f --- /dev/null +++ b/configs/search/BigNAS/train.yaml @@ -0,0 +1,95 @@ +NUM_GPUS: 4 +RNG_SEED: 0 +SPACE: + NAME: 'bignas' +LOADER: + DATASET: 'imagenet' + NUM_CLASSES: 1000 + BATCH_SIZE: 32 + NUM_WORKERS: 4 + USE_VAL: True + TRANSFORM: "auto_augment_tf" +OPTIM: + GRAD_CLIP: 1. + WARMUP_EPOCH: 5 + MAX_EPOCH: 360 + WEIGHT_DECAY: 1.e-5 + BASE_LR: 0.1 + NESTEROV: True +SEARCH: + LOSS_FUN: "cross_entropy_smooth" + LABEL_SMOOTH: 0.1 +TRAIN: + DROP_PATH_PROB: 0.2 +BIGNAS: + SANDWICH_NUM: 4 # max + 2*middle + min + DROP_CONNECT: 0.2 + BN_MOMENTUM: 0. + BN_EPS: 1.e-5 + POST_BN_CALIBRATION_BATCH_NUM: 64 + SUPERNET_CFG: + use_v3_head: True + resolutions: [192, 224, 256, 288] + first_conv: + c: [16, 24] + act_func: 'swish' + s: 2 + mb1: + c: [16, 24] + d: [1, 2] + k: [3, 5] + t: [1] + s: 1 + act_func: 'swish' + se: False + mb2: + c: [24, 32] + d: [3, 4, 5] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: False + mb3: + c: [32, 40] + d: [3, 4, 5, 6] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: True + mb4: + c: [64, 72] + d: [3, 4, 5, 6] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: False + mb5: + c: [112, 120, 128] + d: [3, 4, 5, 6, 7, 8] + k: [3, 5] + t: [4, 5, 6] + s: 1 + act_func: 'swish' + se: True + mb6: + c: [192, 200, 208, 216] + d: [3, 4, 5, 6, 7, 8] + k: [3, 5] + t: [6] + s: 2 + act_func: 'swish' + se: True + mb7: + c: [216, 224] + d: [1, 2] + k: [3, 5] + t: [6] + s: 1 + act_func: 'swish' + se: True + last_conv: + c: [1792, 1984] + act_func: 'swish' \ No newline at end of file diff --git a/examples/search/OFA/train_supernet.sh b/examples/search/OFA/train_supernet.sh index 9466acf..f326d3a 100755 --- a/examples/search/OFA/train_supernet.sh +++ b/examples/search/OFA/train_supernet.sh @@ -1,4 +1,4 @@ -OUT_NAME="OFA_trail_25" +OUT_NAME="OFA_trial_25" TASKS="normal_1 kernel_1 depth_1 depth_2 expand_1 expand_2" for loop in $TASKS diff --git a/scripts/search/OFA/eval_supernet.py b/scripts/search/OFA/eval_supernet.py index 75a5c03..7803b09 100644 --- a/scripts/search/OFA/eval_supernet.py +++ b/scripts/search/OFA/eval_supernet.py @@ -96,7 +96,7 @@ def main(): # load_last_stage_ckpt(cfg.OFA.TASK, cfg.OFA.PHASE) # ofa_trainer.resume() # only load the state_dict of model - # cfg.SEARCH.WEIGHTS = '/home/xfey/XNAS/exp/search/OFA_trail_25/kernel_1/checkpoints/model_epoch_0110.pyth' + # cfg.SEARCH.WEIGHTS = '/home/xfey/XNAS/exp/search/OFA_trial_25/kernel_1/checkpoints/model_epoch_0110.pyth' cfg.SEARCH.WEIGHTS = '/home/xfey/XNAS/tests/weights/ofa_D4_E6_K357' ofa_trainer.resume() @@ -137,10 +137,10 @@ class OFATrainer(KDTrainer): self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.test_meter.log_iter_stats(cur_epoch, cur_iter) self.test_meter.iter_tic() - top1_err = self.test_meter.mb_top1_err.get_win_median() - top5_err = self.test_meter.mb_top5_err.get_win_median() - # self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch) - # self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch) + top1_err = self.test_meter.mb_top1_err.get_win_avg() + top5_err = self.test_meter.mb_top5_err.get_win_avg() + # self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch) + # self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch) # Log epoch stats self.test_meter.log_epoch_stats(cur_epoch) # self.test_meter.reset() diff --git a/scripts/search/OFA/train_supernet.py b/scripts/search/OFA/train_supernet.py index 4aaa9b7..da503f0 100644 --- a/scripts/search/OFA/train_supernet.py +++ b/scripts/search/OFA/train_supernet.py @@ -9,6 +9,7 @@ import torch.nn as nn import torch.nn.functional as F import xnas.core.config as config +from xnas.datasets.loader import get_normal_dataloader import xnas.logger.meter as meter import xnas.logger.logging as logging from xnas.core.config import cfg @@ -44,7 +45,7 @@ def main(local_rank, world_size): # Loss function criterion = criterion_builder() # Data loaders - [train_loader, valid_loader] = construct_loader() + [train_loader, valid_loader] = get_normal_dataloader() # Optimizers net_params = [ # parameters with weight decay @@ -241,10 +242,10 @@ class OFATrainer(KDTrainer): self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.test_meter.log_iter_stats(cur_epoch, cur_iter) self.test_meter.iter_tic() - top1_err = self.test_meter.mb_top1_err.get_win_median() - top5_err = self.test_meter.mb_top5_err.get_win_median() - # self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch) - # self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch) + top1_err = self.test_meter.mb_top1_err.get_win_avg() + top5_err = self.test_meter.mb_top5_err.get_win_avg() + # self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch) + # self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch) # Log epoch stats self.test_meter.log_epoch_stats(cur_epoch) # self.test_meter.reset() @@ -320,8 +321,8 @@ class OFATrainer(KDTrainer): logger.info("Average@all_subnets top1_err:{} top5_err:{}".format(list_mean(top1errs), list_mean(top5errs))) # Saving best model - if self.best_err > top1_err: - self.best_err = top1_err + if self.best_err > list_mean(top1errs): + self.best_err = list_mean(top1errs) self.saving(cur_epoch, best=True) @@ -331,6 +332,5 @@ if __name__ == '__main__': if torch.cuda.is_available(): cfg.NUM_GPUS = torch.cuda.device_count() - print(cfg.NUM_GPUS) mp.spawn(main, nprocs=cfg.NUM_GPUS, args=(cfg.NUM_GPUS,), join=True) diff --git a/tests/ofa_matrices_test.py b/tests/ofa_matrices_test.py index f84af09..a982d34 100644 --- a/tests/ofa_matrices_test.py +++ b/tests/ofa_matrices_test.py @@ -2,7 +2,7 @@ import os import torch def test_local(): - root = '/home/xfey/XNAS/exp/search/OFA_trail_25/kernel_1/checkpoints/' + root = '/home/xfey/XNAS/exp/search/OFA_trial_25/kernel_1/checkpoints/' filename_prefix = 'model_epoch_' filename_postfix = '.pyth' diff --git a/xnas/runner/criterion.py b/xnas/runner/criterion.py index 45f2ffe..788a25e 100644 --- a/xnas/runner/criterion.py +++ b/xnas/runner/criterion.py @@ -32,6 +32,35 @@ def CrossEntropyLoss_label_smoothed(pred, target, label_smoothing=0.): return CrossEntropyLoss_soft_target(pred, soft_target) +class KLLossSoft(torch.nn.modules.loss._Loss): + """ inplace distillation for image classification + output: output logits of the student network + target: output logits of the teacher network + T: temperature + KL(p||q) = Ep \log p - \Ep log q + """ + def forward(self, output, soft_logits, target=None, temperature=1., alpha=0.9): + output, soft_logits = output / temperature, soft_logits / temperature + soft_target_prob = F.softmax(soft_logits, dim=1) + output_log_prob = F.log_softmax(output, dim=1) + kd_loss = -torch.sum(soft_target_prob * output_log_prob, dim=1) + if target is not None: + n_class = output.size(1) + target = torch.zeros_like(output).scatter(1, target.view(-1, 1), 1) + target = target.unsqueeze(1) + output_log_prob = output_log_prob.unsqueeze(2) + ce_loss = -torch.bmm(target, output_log_prob).squeeze() + loss = alpha * temperature * temperature * kd_loss + (1.0 - alpha) * ce_loss + else: + loss = kd_loss + + if self.reduction == 'mean': + return loss.mean() + elif self.reduction == 'sum': + return loss.sum() + return loss + + class MultiHeadCrossEntropyLoss(nn.Module): def forward(self, preds, targets): assert preds.dim() == 3, preds @@ -50,12 +79,15 @@ class MultiHeadCrossEntropyLoss(nn.Module): SUPPORTED_CRITERIONS = { "cross_entropy": torch.nn.CrossEntropyLoss(), + "cross_entropy_soft": CrossEntropyLoss_soft_target, "cross_entropy_smooth": CrossEntropyLoss_label_smoothed, - "cross_entropy_multihead": MultiHeadCrossEntropyLoss() + "cross_entropy_multihead": MultiHeadCrossEntropyLoss(), + "kl_soft": KLLossSoft(), } -def criterion_builder(): +def criterion_builder(criterion=None): err_str = "Loss function type '{}' not supported" - assert cfg.SEARCH.LOSS_FUN in SUPPORTED_CRITERIONS.keys(), err_str.format(cfg.SEARCH.LOSS_FUN) - return SUPPORTED_CRITERIONS[cfg.SEARCH.LOSS_FUN] + loss_fun = cfg.SEARCH.LOSS_FUN if criterion is None else criterion + assert loss_fun in SUPPORTED_CRITERIONS.keys(), err_str.format(loss_fun) + return SUPPORTED_CRITERIONS[loss_fun] diff --git a/xnas/runner/optimizer.py b/xnas/runner/optimizer.py index e523705..c6699f3 100644 --- a/xnas/runner/optimizer.py +++ b/xnas/runner/optimizer.py @@ -1,7 +1,6 @@ """Optimizers.""" import torch -import torch.nn as nn from xnas.core.config import cfg diff --git a/xnas/spaces/BigNAS/cnn.py b/xnas/spaces/BigNAS/cnn.py new file mode 100644 index 0000000..578ab17 --- /dev/null +++ b/xnas/spaces/BigNAS/cnn.py @@ -0,0 +1,653 @@ +# Implementation adapted from AttentiveNAS: https://github.com/facebookresearch/AttentiveNAS + +import random +from copy import deepcopy +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from xnas.spaces.OFA.ops import ResidualBlock +from xnas.spaces.OFA.dynamic_ops import DynamicLinearLayer +from xnas.spaces.OFA.utils import val2list, make_divisible +from xnas.spaces.BigNAS.dynamic_layers import DynamicMBConvLayer, DynamicConvLayer, DynamicShortcutLayer + + +class BigNASStaticModel(nn.Module): + + def __init__(self, first_conv, blocks, last_conv, classifier, resolution, use_v3_head=True): + super(BigNASStaticModel, self).__init__() + + self.first_conv = first_conv + self.blocks = nn.ModuleList(blocks) + self.last_conv = last_conv + self.classifier = classifier + + self.resolution = resolution #input size + self.use_v3_head = use_v3_head + + def forward(self, x): + # resize input to target resolution first + # Rule: transform images into different sizes + if x.size(-1) != self.resolution: + x = F.interpolate(x, size=self.resolution, mode='bicubic') + + x = self.first_conv(x) + for block in self.blocks: + x = block(x) + x = self.last_conv(x) + if not self.use_v3_head: + x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling + x = torch.squeeze(x) + x = self.classifier(x) + return x + + + @property + def module_str(self): + _str = self.first_conv.module_str + '\n' + for block in self.blocks: + _str += block.module_str + '\n' + #_str += self.last_conv.module_str + '\n' + _str += self.classifier.module_str + return _str + + @property + def config(self): + return { + 'name': BigNASStaticModel.__name__, + 'bn': self.get_bn_param(), + 'first_conv': self.first_conv.config, + 'blocks': [ + block.config for block in self.blocks + ], + #'last_conv': self.last_conv.config, + 'classifier': self.classifier.config, + 'resolution': self.resolution + } + + + def weight_initialization(self): + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + @staticmethod + def build_from_config(config): + raise NotImplementedError + + def set_bn_param(self, momentum, eps): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm): + if momentum is not None: + m.momentum = float(momentum) + else: + m.momentum = None + m.eps = float(eps) + return + + def get_bn_param(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm): + return { + 'momentum': m.momentum, + 'eps': m.eps, + } + return None + + def reset_running_stats_for_calibration(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm): + m.training = True + m.momentum = None # cumulative moving average + m.reset_running_stats() + + +class BigNASDynamicModel(nn.Module): + + def __init__(self, supernet_cfg, n_classes=1000, bn_param=(0., 1e-5)): + super(BigNASDynamicModel, self).__init__() + + self.supernet_cfg = supernet_cfg + self.n_classes = n_classes + self.use_v3_head = getattr(self.supernet_cfg, 'use_v3_head', False) + self.stage_names = ['first_conv', 'mb1', 'mb2', 'mb3', 'mb4', 'mb5', 'mb6', 'mb7', 'last_conv'] + + self.width_list, self.depth_list, self.ks_list, self.expand_ratio_list = [], [], [], [] + for name in self.stage_names: + block_cfg = getattr(self.supernet_cfg, name) + self.width_list.append(block_cfg.c) + if name.startswith('mb'): + self.depth_list.append(block_cfg.d) + self.ks_list.append(block_cfg.k) + self.expand_ratio_list.append(block_cfg.t) + self.resolution_list = self.supernet_cfg.resolutions + + self.cfg_candidates = { + 'resolution': self.resolution_list , + 'width': self.width_list, + 'depth': self.depth_list, + 'kernel_size': self.ks_list, + 'expand_ratio': self.expand_ratio_list + } + + #first conv layer, including conv, bn, act + out_channel_list, act_func, stride = \ + self.supernet_cfg.first_conv.c, self.supernet_cfg.first_conv.act_func, self.supernet_cfg.first_conv.s + self.first_conv = DynamicConvLayer( + in_channel_list=val2list(3), out_channel_list=out_channel_list, + kernel_size=3, stride=stride, act_func=act_func, + ) + + # inverted residual blocks + self.block_group_info = [] + blocks = [] + _block_index = 0 + feature_dim = out_channel_list + for stage_id, key in enumerate(self.stage_names[1:-1]): + block_cfg = getattr(self.supernet_cfg, key) + width = block_cfg.c + n_block = max(block_cfg.d) + act_func = block_cfg.act_func + ks = block_cfg.k + expand_ratio_list = block_cfg.t + use_se = block_cfg.se + + self.block_group_info.append([_block_index + i for i in range(n_block)]) + _block_index += n_block + + output_channel = width + for i in range(n_block): + stride = block_cfg.s if i == 0 else 1 + if min(expand_ratio_list) >= 4: + expand_ratio_list = [_s for _s in expand_ratio_list if _s >= 4] if i == 0 else expand_ratio_list + mobile_inverted_conv = DynamicMBConvLayer( + in_channel_list=feature_dim, + out_channel_list=output_channel, + kernel_size_list=ks, + expand_ratio_list=expand_ratio_list, + stride=stride, + act_func=act_func, + use_se=use_se, + channels_per_group=getattr(self.supernet_cfg, 'channels_per_group', 1) + ) + # Rule: add skip-connect, and use 2x2 AvgPool or 1x1 Conv for adaptation + shortcut = DynamicShortcutLayer(feature_dim, output_channel, reduction=stride) + blocks.append(ResidualBlock(mobile_inverted_conv, shortcut)) + feature_dim = output_channel + self.blocks = nn.ModuleList(blocks) + + last_channel, act_func = self.supernet_cfg.last_conv.c, self.supernet_cfg.last_conv.act_func + if not self.use_v3_head: + self.last_conv = DynamicConvLayer( + in_channel_list=feature_dim, out_channel_list=last_channel, + kernel_size=1, act_func=act_func, + ) + else: + expand_feature_dim = [f_dim * 6 for f_dim in feature_dim] + self.last_conv = nn.Sequential(OrderedDict([ + ('final_expand_layer', DynamicConvLayer( + feature_dim, expand_feature_dim, kernel_size=1, use_bn=True, act_func=act_func) + ), + ('pool', nn.AdaptiveAvgPool2d((1,1))), + ('feature_mix_layer', DynamicConvLayer( + in_channel_list=expand_feature_dim, out_channel_list=last_channel, + kernel_size=1, act_func=act_func, use_bn=False,) + ), + ])) + + #final conv layer + self.classifier = DynamicLinearLayer( + in_features_list=last_channel, out_features=n_classes, bias=True + ) + + # set bn param + self.set_bn_param(momentum=bn_param[0], eps=bn_param[1]) + + # runtime_depth + self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info] + + self.zero_residual_block_bn_weights() + + self.active_dropout_rate = 0 + self.active_drop_connect_rate = 0 + self.active_resolution = 224 + + # Rule: Initialize learnable coefficient \gamma=0 + def zero_residual_block_bn_weights(self): + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, ResidualBlock): + if isinstance(m.mobile_inverted_conv, DynamicMBConvLayer) and m.shortcut is not None: + m.mobile_inverted_conv.point_linear.bn.bn.weight.zero_() + + @staticmethod + def name(): + return 'BigNASDynamicModel' + + def forward(self, x): + # resize input to target resolution first + if x.size(-1) != self.active_resolution: + x = F.interpolate(x, size=self.active_resolution, mode='bicubic') + + # first conv + x = self.first_conv(x) + # blocks + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + for idx in active_idx: + x = self.blocks[idx](x) + + x = self.last_conv(x) + x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling + x = torch.squeeze(x) + + if self.active_dropout_rate > 0 and self.training: + x = F.dropout(x, p = self.active_dropout_rate) + + x = self.classifier(x) + return x + + + @property + def module_str(self): + _str = self.first_conv.module_str + '\n' + _str += self.blocks[0].module_str + '\n' + + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + for idx in active_idx: + _str += self.blocks[idx].module_str + '\n' + if not self.use_v3_head: + _str += self.last_conv.module_str + '\n' + else: + _str += self.last_conv.final_expand_layer.module_str + '\n' + _str += self.last_conv.feature_mix_layer.module_str + '\n' + _str += self.classifier.module_str + '\n' + return _str + + @property + def config(self): + return { + 'name': BigNASDynamicModel.__name__, + 'bn': self.get_bn_param(), + 'first_conv': self.first_conv.config, + 'blocks': [ + block.config for block in self.blocks + ], + 'last_conv': self.last_conv.config if not self.use_v3_head else None, + 'final_expand_layer': self.last_conv.final_expand_layer if self.use_v3_head else None, + 'feature_mix_layer': self.last_conv.feature_mix_layer if self.use_v3_head else None, + 'classifier': self.classifier.config, + 'resolution': self.active_resolution + } + + + @staticmethod + def build_from_config(config): + raise NotImplementedError + + def get_parameters(self, keys=None, mode="include"): + if keys is None: + for name, param in self.named_parameters(): + if param.requires_grad: + yield param + elif mode == "include": + for name, param in self.named_parameters(): + flag = False + for key in keys: + if key in name: + flag = True + break + if flag and param.requires_grad: + yield param + elif mode == "exclude": + for name, param in self.named_parameters(): + flag = True + for key in keys: + if key in name: + flag = False + break + if flag and param.requires_grad: + yield param + else: + raise ValueError("do not support: %s" % mode) + + def set_bn_param(self, momentum, eps): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm): + if momentum is not None: + m.momentum = float(momentum) + else: + m.momentum = None + m.eps = float(eps) + return + + def get_bn_param(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm): + return { + 'momentum': m.momentum, + 'eps': m.eps, + } + return None + + """ set, sample and get active sub-networks """ + def set_active_subnet(self, resolution=224, width=None, depth=None, kernel_size=None, expand_ratio=None, **kwargs): + assert len(depth) == len(kernel_size) == len(expand_ratio) == len(width) - 2 + #set resolution + self.active_resolution = resolution + + # first conv + self.first_conv.active_out_channel = width[0] + + for stage_id, (c, k, e, d) in enumerate(zip(width[1:-1], kernel_size, expand_ratio, depth)): + start_idx, end_idx = min(self.block_group_info[stage_id]), max(self.block_group_info[stage_id]) + for block_id in range(start_idx, start_idx+d): + block = self.blocks[block_id] + #block output channels + block.mobile_inverted_conv.active_out_channel = c + if block.shortcut is not None: + block.shortcut.active_out_channel = c + + #dw kernel size + block.mobile_inverted_conv.active_kernel_size = k + + #dw expansion ration + block.mobile_inverted_conv.active_expand_ratio = e + + #IRBlocks repated times + for i, d in enumerate(depth): + self.runtime_depth[i] = min(len(self.block_group_info[i]), d) + + #last conv + if not self.use_v3_head: + self.last_conv.active_out_channel = width[-1] + else: + # default expansion ratio: 6 + self.last_conv.final_expand_layer.active_out_channel = width[-2] * 6 + self.last_conv.feature_mix_layer.active_out_channel = width[-1] + + + def get_active_subnet_settings(self): + r = self.active_resolution + width, depth, kernel_size, expand_ratio= [], [], [], [] + + #first conv + width.append(self.first_conv.active_out_channel) + for stage_id in range(len(self.block_group_info)): + start_idx = min(self.block_group_info[stage_id]) + block = self.blocks[start_idx] #first block + width.append(block.mobile_inverted_conv.active_out_channel) + kernel_size.append(block.mobile_inverted_conv.active_kernel_size) + expand_ratio.append(block.mobile_inverted_conv.active_expand_ratio) + depth.append(self.runtime_depth[stage_id]) + + if not self.use_v3_head: + width.append(self.last_conv.active_out_channel) + else: + width.append(self.last_conv.feature_mix_layer.active_out_channel) + + return { + 'resolution': r, + 'width': width, + 'kernel_size': kernel_size, + 'expand_ratio': expand_ratio, + 'depth': depth, + } + + def set_dropout_rate(self, dropout=0, drop_connect=0, drop_connect_only_last_two_stages=True): + self.active_dropout_rate = dropout + for idx, block in enumerate(self.blocks): + if drop_connect_only_last_two_stages: + if idx not in self.block_group_info[-1] + self.block_group_info[-2]: + continue + this_drop_connect_rate = drop_connect * float(idx) / len(self.blocks) + block.drop_connect_rate = this_drop_connect_rate + + + def sample_min_subnet(self): + return self._sample_active_subnet(min_net=True) + + + def sample_max_subnet(self): + return self._sample_active_subnet(max_net=True) + + + def sample_active_subnet(self, compute_flops=False): + cfg = self._sample_active_subnet( + False, False + ) + if compute_flops: + cfg['flops'] = self.compute_active_subnet_flops() + return cfg + + + def sample_active_subnet_within_range(self, targeted_min_flops, targeted_max_flops): + while True: + cfg = self._sample_active_subnet() + cfg['flops'] = self.compute_active_subnet_flops() + if cfg['flops'] >= targeted_min_flops and cfg['flops'] <= targeted_max_flops: + return cfg + + def _sample_active_subnet(self, min_net=False, max_net=False): + + sample_cfg = lambda candidates, sample_min, sample_max: \ + min(candidates) if sample_min else (max(candidates) if sample_max else random.choice(candidates)) + + cfg = {} + # sample a resolution + cfg['resolution'] = sample_cfg(self.cfg_candidates['resolution'], min_net, max_net) + for k in ['width', 'depth', 'kernel_size', 'expand_ratio']: + cfg[k] = [] + for vv in self.cfg_candidates[k]: + cfg[k].append(sample_cfg(val2list(vv), min_net, max_net)) + + self.set_active_subnet( + cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio'] + ) + return cfg + + + def mutate_and_reset(self, cfg, prob=0.1, keep_resolution=False): + cfg = deepcopy(cfg) + pick_another = lambda x, candidates: x if len(candidates) == 1 else random.choice([v for v in candidates if v != x]) + # sample a resolution + r = random.random() + if r < prob and not keep_resolution: + cfg['resolution'] = pick_another(cfg['resolution'], self.cfg_candidates['resolution']) + + # sample channels, depth, kernel_size, expand_ratio + for k in ['width', 'depth', 'kernel_size', 'expand_ratio']: + for _i, _v in enumerate(cfg[k]): + r = random.random() + if r < prob: + cfg[k][_i] = pick_another(cfg[k][_i], val2list(self.cfg_candidates[k][_i])) + + self.set_active_subnet( + cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio'] + ) + return cfg + + + def crossover_and_reset(self, cfg1, cfg2, p=0.5): + def _cross_helper(g1, g2, prob): + assert type(g1) == type(g2) + if isinstance(g1, int): + return g1 if random.random() < prob else g2 + elif isinstance(g1, list): + return [v1 if random.random() < prob else v2 for v1, v2 in zip(g1, g2)] + else: + raise NotImplementedError + + cfg = {} + cfg['resolution'] = cfg1['resolution'] if random.random() < p else cfg2['resolution'] + for k in ['width', 'depth', 'kernel_size', 'expand_ratio']: + cfg[k] = _cross_helper(cfg1[k], cfg2[k], p) + + self.set_active_subnet( + cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio'] + ) + return cfg + + + def get_active_subnet(self, preserve_weight=True): + with torch.no_grad(): + first_conv = self.first_conv.get_active_subnet(3, preserve_weight) + + blocks = [] + input_channel = first_conv.out_channels + # blocks + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + stage_blocks = [] + for idx in active_idx: + stage_blocks.append(ResidualBlock( + self.blocks[idx].mobile_inverted_conv.get_active_subnet(input_channel, preserve_weight), + self.blocks[idx].shortcut.get_active_subnet(input_channel, preserve_weight) if self.blocks[idx].shortcut is not None else None + )) + input_channel = stage_blocks[-1].mobile_inverted_conv.out_channels + blocks += stage_blocks + + if not self.use_v3_head: + last_conv = self.last_conv.get_active_subnet(input_channel, preserve_weight) + in_features = last_conv.out_channels + else: + final_expand_layer = self.last_conv.final_expand_layer.get_active_subnet(input_channel, preserve_weight) + feature_mix_layer = self.last_conv.feature_mix_layer.get_active_subnet(input_channel*6, preserve_weight) + in_features = feature_mix_layer.out_channels + last_conv = nn.Sequential( + final_expand_layer, + nn.AdaptiveAvgPool2d((1,1)), + feature_mix_layer + ) + + classifier = self.classifier.get_active_subnet(in_features, preserve_weight) + + _subnet = BigNASStaticModel( + first_conv, blocks, last_conv, classifier, self.active_resolution, use_v3_head=self.use_v3_head + ) + _subnet.set_bn_param(**self.get_bn_param()) + return _subnet + + + def compute_active_subnet_flops(self): + + def count_conv(c_in, c_out, size_out, groups, k): + kernel_ops = k**2 + output_elements = c_out * size_out**2 + ops = c_in * output_elements * kernel_ops / groups + return ops + + def count_linear(c_in, c_out): + return c_in * c_out + + total_ops = 0 + + c_in = 3 + size_out = self.active_resolution // self.first_conv.stride + c_out = self.first_conv.active_out_channel + + total_ops += count_conv(c_in, c_out, size_out, 1, 3) + c_in = c_out + + # mb blocks + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + for idx in active_idx: + block = self.blocks[idx] + c_middle = make_divisible(round(c_in * block.mobile_inverted_conv.active_expand_ratio), 8) + # 1*1 conv + if block.mobile_inverted_conv.inverted_bottleneck is not None: + total_ops += count_conv(c_in, c_middle, size_out, 1, 1) + # dw conv + stride = 1 if idx > active_idx[0] else block.mobile_inverted_conv.stride + if size_out % stride == 0: + size_out = size_out // stride + else: + size_out = (size_out +1) // stride + total_ops += count_conv(c_middle, c_middle, size_out, c_middle, block.mobile_inverted_conv.active_kernel_size) + # 1*1 conv + c_out = block.mobile_inverted_conv.active_out_channel + total_ops += count_conv(c_middle, c_out, size_out, 1, 1) + #se + if block.mobile_inverted_conv.use_se: + num_mid = make_divisible(c_middle // block.mobile_inverted_conv.depth_conv.se.reduction, divisor=8) + total_ops += count_conv(c_middle, num_mid, 1, 1, 1) * 2 + if block.shortcut and c_in != c_out: + total_ops += count_conv(c_in, c_out, size_out, 1, 1) + c_in = c_out + + if not self.use_v3_head: + c_out = self.last_conv.active_out_channel + total_ops += count_conv(c_in, c_out, size_out, 1, 1) + else: + c_expand = self.last_conv.final_expand_layer.active_out_channel + c_out = self.last_conv.feature_mix_layer.active_out_channel + total_ops += count_conv(c_in, c_expand, size_out, 1, 1) + total_ops += count_conv(c_expand, c_out, 1, 1, 1) + + # n_classes + total_ops += count_linear(c_out, self.n_classes) + return total_ops / 1e6 + + + def load_weights_from_pretrained_models(self, checkpoint_path): + with open(checkpoint_path, 'rb') as f: + checkpoint = torch.load(f, map_location='cpu') + assert isinstance(checkpoint, dict) + pretrained_state_dicts = checkpoint['model_state'] + for k, v in self.state_dict().items(): + # name = 'module.' + k if not k.startswith('module') else k + name = k + v.copy_(pretrained_state_dicts[name]) + + +def _BigNAS_CNN(): + from xnas.core.config import cfg + bn_momentum = cfg.BIGNAS.BN_MOMENTUM + bn_eps = cfg.BIGNAS.BN_EPS + return BigNASDynamicModel( + cfg.BIGNAS.SUPERNET_CFG, + cfg.LOADER.NUM_CLASSES, + (bn_momentum, bn_eps), + ) + +def _infer_BigNAS_CNN(): + from xnas.core.config import cfg + bn_momentum = cfg.BIGNAS.BN_MOMENTUM + bn_eps = cfg.BIGNAS.BN_EPS + supernet = BigNASDynamicModel( + cfg.BIGNAS.SUPERNET_CFG, + cfg.LOADER.NUM_CLASSES, + (bn_momentum, bn_eps), + ) + # namespace changed: pareto_models.supernet_checkpoint_path + supernet.load_weights_from_pretrained_models(cfg.SEARCH.WEIGHT) + # namespace created: active_subnet.* + supernet.set_active_subnet( + resolution=cfg.BIGNAS.ACTIVE_SUBNET.RESOLUTION, + width = cfg.BIGNAS.ACTIVE_SUBNET.WIDTH, + depth = cfg.BIGNAS.ACTIVE_SUBNET.DEPTH, + kernel_size = cfg.BIGNAS.ACTIVE_SUBNET.KERNEL_SIZE, + expand_ratio = cfg.BIGNAS.ACTIVE_SUBNET.EXPAND_RATIO, + ) + model = supernet.get_active_subnet() + # house-keeping stuff: may using different values with supernet + model.set_bn_param(momentum=bn_momentum, eps=bn_eps) + del supernet + return model diff --git a/xnas/spaces/BigNAS/dynamic_layers.py b/xnas/spaces/BigNAS/dynamic_layers.py new file mode 100644 index 0000000..dfe5557 --- /dev/null +++ b/xnas/spaces/BigNAS/dynamic_layers.py @@ -0,0 +1,331 @@ +from collections import OrderedDict + +import torch.nn as nn +import torch.nn.functional as F + + +from xnas.spaces.OFA.utils import val2list +from xnas.spaces.OFA.ops import SEModule, ConvLayer, ShortcutLayer, build_activation, make_divisible +from xnas.spaces.OFA.dynamic_ops import DynamicConv2d, DynamicSE, copy_bn +from xnas.spaces.BigNAS.ops import MBConvLayer +from xnas.spaces.BigNAS.dynamic_ops import DynamicSeparableConv2d, DynamicBatchNorm2d + + +class DynamicMBConvLayer(nn.Module): + + def __init__(self, in_channel_list, out_channel_list, + kernel_size_list=3, expand_ratio_list=6, stride=1, act_func='relu6', use_se=False, channels_per_group=1): + super(DynamicMBConvLayer, self).__init__() + + self.in_channel_list = val2list(in_channel_list) + self.out_channel_list = val2list(out_channel_list) + + self.kernel_size_list = val2list(kernel_size_list, 1) + self.expand_ratio_list = val2list(expand_ratio_list, 1) + + self.stride = stride + self.act_func = act_func + self.use_se = use_se + self.channels_per_group = channels_per_group + + # build modules + max_middle_channel = round(max(self.in_channel_list) * max(self.expand_ratio_list)) + if max(self.expand_ratio_list) == 1: + self.inverted_bottleneck = None + else: + self.inverted_bottleneck = nn.Sequential(OrderedDict([ + ('conv', DynamicConv2d(max(self.in_channel_list), max_middle_channel)), + ('bn', DynamicBatchNorm2d(max_middle_channel)), + ('act', build_activation(self.act_func, inplace=True)), + ])) + + self.depth_conv = nn.Sequential(OrderedDict([ + ('conv', DynamicSeparableConv2d(max_middle_channel, self.kernel_size_list, stride=self.stride, channels_per_group=self.channels_per_group)), + ('bn', DynamicBatchNorm2d(max_middle_channel)), + ('act', build_activation(self.act_func, inplace=True)) + ])) + if self.use_se: + self.depth_conv.add_module('se', DynamicSE(max_middle_channel)) + + self.point_linear = nn.Sequential(OrderedDict([ + ('conv', DynamicConv2d(max_middle_channel, max(self.out_channel_list))), + ('bn', DynamicBatchNorm2d(max(self.out_channel_list))), + ])) + + self.active_kernel_size = max(self.kernel_size_list) + self.active_expand_ratio = max(self.expand_ratio_list) + self.active_out_channel = max(self.out_channel_list) + + def forward(self, x): + in_channel = x.size(1) + + if self.inverted_bottleneck is not None: + self.inverted_bottleneck.conv.active_out_channel = \ + make_divisible(round(in_channel * self.active_expand_ratio), 8) + + self.depth_conv.conv.active_kernel_size = self.active_kernel_size + self.point_linear.conv.active_out_channel = self.active_out_channel + + if self.inverted_bottleneck is not None: + x = self.inverted_bottleneck(x) + x = self.depth_conv(x) + x = self.point_linear(x) + return x + + @property + def module_str(self): + if self.use_se: + return 'SE(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size) + else: + return '(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size) + + @property + def config(self): + return { + 'name': DynamicMBConvLayer.__name__, + 'in_channel_list': self.in_channel_list, + 'out_channel_list': self.out_channel_list, + 'kernel_size_list': self.kernel_size_list, + 'expand_ratio_list': self.expand_ratio_list, + 'stride': self.stride, + 'act_func': self.act_func, + 'use_se': self.use_se, + 'channels_per_group': self.channels_per_group, + } + + @staticmethod + def build_from_config(config): + return DynamicMBConvLayer(**config) + + ############################################################################################ + + def get_active_subnet(self, in_channel, preserve_weight=True): + middle_channel = make_divisible(round(in_channel * self.active_expand_ratio), 8) + channels_per_group = self.depth_conv.conv.channels_per_group + + # build the new layer + sub_layer = MBConvLayer( + in_channel, self.active_out_channel, self.active_kernel_size, self.stride, self.active_expand_ratio, + act_func=self.act_func, mid_channels=middle_channel, use_se=self.use_se, channels_per_group=channels_per_group + ) + sub_layer = sub_layer.to(self.parameters().__next__().device) + + if not preserve_weight: + return sub_layer + + # copy weight from current layer + if sub_layer.inverted_bottleneck is not None: + sub_layer.inverted_bottleneck.conv.weight.data.copy_( + self.inverted_bottleneck.conv.conv.weight.data[:middle_channel, :in_channel, :, :] + ) + copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn) + + sub_layer.depth_conv.conv.weight.data.copy_( + self.depth_conv.conv.get_active_filter(middle_channel, self.active_kernel_size).data + ) + copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn) + + if self.use_se: + se_mid = make_divisible(middle_channel // SEModule.REDUCTION, divisor=8) + sub_layer.depth_conv.se.fc.reduce.weight.data.copy_( + self.depth_conv.se.fc.reduce.weight.data[:se_mid, :middle_channel, :, :] + ) + sub_layer.depth_conv.se.fc.reduce.bias.data.copy_(self.depth_conv.se.fc.reduce.bias.data[:se_mid]) + + sub_layer.depth_conv.se.fc.expand.weight.data.copy_( + self.depth_conv.se.fc.expand.weight.data[:middle_channel, :se_mid, :, :] + ) + sub_layer.depth_conv.se.fc.expand.bias.data.copy_(self.depth_conv.se.fc.expand.bias.data[:middle_channel]) + + sub_layer.point_linear.conv.weight.data.copy_( + self.point_linear.conv.conv.weight.data[:self.active_out_channel, :middle_channel, :, :] + ) + copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn) + + return sub_layer + + def re_organize_middle_weights(self, expand_ratio_stage=0): + # importance = torch.sum(torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3)) + # if expand_ratio_stage > 0: + # sorted_expand_list = copy.deepcopy(self.expand_ratio_list) + # sorted_expand_list.sort(reverse=True) + # target_width = sorted_expand_list[expand_ratio_stage] + # target_width = round(max(self.in_channel_list) * target_width) + # importance[target_width:] = torch.arange(0, target_width - importance.size(0), -1) + + # sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True) + # self.point_linear.conv.conv.weight.data = torch.index_select( + # self.point_linear.conv.conv.weight.data, 1, sorted_idx + # ) + + # adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx) + # self.depth_conv.conv.conv.weight.data = torch.index_select( + # self.depth_conv.conv.conv.weight.data, 0, sorted_idx + # ) + + # if self.use_se: + # # se expand: output dim 0 reorganize + # se_expand = self.depth_conv.se.fc.expand + # se_expand.weight.data = torch.index_select(se_expand.weight.data, 0, sorted_idx) + # se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx) + # # se reduce: input dim 1 reorganize + # se_reduce = self.depth_conv.se.fc.reduce + # se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 1, sorted_idx) + # # middle weight reorganize + # se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3)) + # se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True) + + # se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx) + # se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx) + # se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx) + + # # TODO if inverted_bottleneck is None, the previous layer should be reorganized accordingly + # if self.inverted_bottleneck is not None: + # adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx) + # self.inverted_bottleneck.conv.conv.weight.data = torch.index_select( + # self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx + # ) + # return None + # else: + # return sorted_idx + raise NotImplementedError + + +class DynamicConvLayer(nn.Module): + + def __init__(self, in_channel_list, out_channel_list, kernel_size=3, stride=1, dilation=1, + use_bn=True, act_func='relu6'): + super(DynamicConvLayer, self).__init__() + + self.in_channel_list = val2list(in_channel_list) + self.out_channel_list = val2list(out_channel_list) + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.use_bn = use_bn + self.act_func = act_func + + self.conv = DynamicConv2d( + max_in_channels=max(self.in_channel_list), max_out_channels=max(self.out_channel_list), + kernel_size=self.kernel_size, stride=self.stride, dilation=self.dilation, + ) + if self.use_bn: + self.bn = DynamicBatchNorm2d(max(self.out_channel_list)) + + if self.act_func is not None: + self.act = build_activation(self.act_func, inplace=True) + + self.active_out_channel = max(self.out_channel_list) + + def forward(self, x): + self.conv.active_out_channel = self.active_out_channel + + x = self.conv(x) + if self.use_bn: + x = self.bn(x) + if self.act_func is not None: + x = self.act(x) + return x + + @property + def module_str(self): + return 'DyConv(O%d, K%d, S%d)' % (self.active_out_channel, self.kernel_size, self.stride) + + @property + def config(self): + return { + 'name': DynamicConvLayer.__name__, + 'in_channel_list': self.in_channel_list, + 'out_channel_list': self.out_channel_list, + 'kernel_size': self.kernel_size, + 'stride': self.stride, + 'dilation': self.dilation, + 'use_bn': self.use_bn, + 'act_func': self.act_func, + } + + @staticmethod + def build_from_config(config): + return DynamicConvLayer(**config) + + def get_active_subnet(self, in_channel, preserve_weight=True): + sub_layer = ConvLayer( + in_channel, self.active_out_channel, self.kernel_size, self.stride, self.dilation, + use_bn=self.use_bn, act_func=self.act_func + ) + sub_layer = sub_layer.to(self.parameters().__next__().device) + + if not preserve_weight: + return sub_layer + + sub_layer.conv.weight.data.copy_(self.conv.conv.weight.data[:self.active_out_channel, :in_channel, :, :]) + if self.use_bn: + copy_bn(sub_layer.bn, self.bn.bn) + + return sub_layer + + +class DynamicShortcutLayer(nn.Module): + + def __init__(self, in_channel_list, out_channel_list, reduction=1): + super(DynamicShortcutLayer, self).__init__() + + self.in_channel_list = val2list(in_channel_list) + self.out_channel_list = val2list(out_channel_list) + self.reduction = reduction + + self.conv = DynamicConv2d( + max_in_channels=max(self.in_channel_list), max_out_channels=max(self.out_channel_list), + kernel_size=1, stride=1, + ) + + self.active_out_channel = max(self.out_channel_list) + + def forward(self, x): + in_channel = x.size(1) + + #identity mapping + if in_channel == self.active_out_channel and self.reduction == 1: + return x + #average pooling, if size doesn't match + if self.reduction > 1: + padding = 0 if x.size(-1) % 2 == 0 else 1 + x = F.avg_pool2d(x, self.reduction, padding=padding) + + #1*1 conv, if #channels doesn't match + if in_channel != self.active_out_channel: + self.conv.active_out_channel = self.active_out_channel + x = self.conv(x) + return x + + @property + def module_str(self): + return 'DyShortcut(O%d, R%d)' % (self.active_out_channel, self.reduction) + + @property + def config(self): + return { + 'name': DynamicShortcutLayer.__name__, + 'in_channel_list': self.in_channel_list, + 'out_channel_list': self.out_channel_list, + 'reduction': self.reduction, + } + + @staticmethod + def build_from_config(config): + return DynamicShortcutLayer(**config) + + def get_active_subnet(self, in_channel, preserve_weight=True): + sub_layer = ShortcutLayer( + in_channel, self.active_out_channel, self.reduction + ) + sub_layer = sub_layer.to(self.parameters().__next__().device) + + if not preserve_weight: + return sub_layer + + sub_layer.conv.weight.data.copy_(self.conv.conv.weight.data[:self.active_out_channel, :in_channel, :, :]) + + return sub_layer + + diff --git a/xnas/spaces/BigNAS/dynamic_ops.py b/xnas/spaces/BigNAS/dynamic_ops.py new file mode 100644 index 0000000..f67b47c --- /dev/null +++ b/xnas/spaces/BigNAS/dynamic_ops.py @@ -0,0 +1,181 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.autograd.function import Function + +from xnas.spaces.OFA.ops import get_same_padding +from xnas.spaces.OFA.dynamic_ops import sub_filter_start_end + + + +class DynamicSeparableConv2d(nn.Module): + # KERNEL_TRANSFORM_MODE = None # None or 1 + + def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1, channels_per_group=1): + super(DynamicSeparableConv2d, self).__init__() + + self.max_in_channels = max_in_channels + self.channels_per_group = channels_per_group + assert self.max_in_channels % self.channels_per_group == 0 + self.kernel_size_list = kernel_size_list + self.stride = stride + self.dilation = dilation + + self.conv = nn.Conv2d( + self.max_in_channels, self.max_in_channels, max(self.kernel_size_list), self.stride, + groups=self.max_in_channels // self.channels_per_group, bias=False, + ) + + self._ks_set = list(set(self.kernel_size_list)) + self._ks_set.sort() # e.g., [3, 5, 7] + # if self.KERNEL_TRANSFORM_MODE is not None: + # # register scaling parameters + # # 7to5_matrix, 5to3_matrix + # scale_params = {} + # for i in range(len(self._ks_set) - 1): + # ks_small = self._ks_set[i] + # ks_larger = self._ks_set[i + 1] + # param_name = '%dto%d' % (ks_larger, ks_small) + # scale_params['%s_matrix' % param_name] = Parameter(torch.eye(ks_small ** 2)) + # for name, param in scale_params.items(): + # self.register_parameter(name, param) + + self.active_kernel_size = max(self.kernel_size_list) + + def get_active_filter(self, in_channel, kernel_size): + out_channel = in_channel + max_kernel_size = max(self.kernel_size_list) + + start, end = sub_filter_start_end(max_kernel_size, kernel_size) + filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end] + # if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size: + # start_filter = self.conv.weight[:out_channel, :in_channel, :, :] # start with max kernel + # for i in range(len(self._ks_set) - 1, 0, -1): + # src_ks = self._ks_set[i] + # if src_ks <= kernel_size: + # break + # target_ks = self._ks_set[i - 1] + # start, end = sub_filter_start_end(src_ks, target_ks) + # _input_filter = start_filter[:, :, start:end, start:end] + # _input_filter = _input_filter.contiguous() + # _input_filter = _input_filter.view(_input_filter.size(0), _input_filter.size(1), -1) + # _input_filter = _input_filter.view(-1, _input_filter.size(2)) + # _input_filter = F.linear( + # _input_filter, self.__getattr__('%dto%d_matrix' % (src_ks, target_ks)), + # ) + # _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks ** 2) + # _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks, target_ks) + # start_filter = _input_filter + # filters = start_filter + return filters + + def forward(self, x, kernel_size=None): + if kernel_size is None: + kernel_size = self.active_kernel_size + in_channel = x.size(1) + assert in_channel % self.channels_per_group == 0 + + filters = self.get_active_filter(in_channel, kernel_size).contiguous() + + padding = get_same_padding(kernel_size) + y = F.conv2d( + x, filters, None, self.stride, padding, self.dilation, in_channel // self.channels_per_group + ) + return y + + +class AllReduce(Function): + @staticmethod + def forward(ctx, input): + input_list = [torch.zeros_like(input) for k in range(dist.get_world_size())] + # Use allgather instead of allreduce since I don't trust in-place operations .. + dist.all_gather(input_list, input, async_op=False) + inputs = torch.stack(input_list, dim=0) + return torch.sum(inputs, dim=0) + + @staticmethod + def backward(ctx, grad_output): + dist.all_reduce(grad_output, async_op=False) + return grad_output + + +class DynamicBatchNorm2d(nn.Module): + ''' + 1. doesn't acculate bn statistics, (momentum=0.) + 2. calculate BN statistics of all subnets after training + 3. bn weights are shared + https://arxiv.org/abs/1903.05134 + https://detectron2.readthedocs.io/_modules/detectron2/layers/batch_norm.html + ''' + #SET_RUNNING_STATISTICS = False + + def __init__(self, max_feature_dim): + super(DynamicBatchNorm2d, self).__init__() + + self.max_feature_dim = max_feature_dim + self.bn = nn.BatchNorm2d(self.max_feature_dim) + + # self.exponential_average_factor = 0 # doesn't acculate bn stats + self.need_sync = False # sync-batchnormalization, suggested to use in bignas + + # reserved to tracking the performance of the largest and smallest network + self.bn_tracking = nn.ModuleList( + [ + nn.BatchNorm2d(self.max_feature_dim, affine=False), + nn.BatchNorm2d(self.max_feature_dim, affine=False) + ] + ) + + def forward(self, x): + feature_dim = x.size(1) + if not self.training: + raise ValueError('DynamicBN only supports training') + + bn = self.bn + # need_sync + if not self.need_sync: + return F.batch_norm( + x, bn.running_mean[:feature_dim], bn.running_var[:feature_dim], bn.weight[:feature_dim], + bn.bias[:feature_dim], bn.training or not bn.track_running_stats, + bn.momentum, bn.eps, + ) + else: + assert dist.get_world_size() > 1, 'SyncBatchNorm requires >1 world size' + B, C = x.shape[0], x.shape[1] + mean = torch.mean(x, dim=[0, 2, 3]) + meansqr = torch.mean(x * x, dim=[0, 2, 3]) + assert B > 0, 'does not support zero batch size' + vec = torch.cat([mean, meansqr], dim=0) + vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size()) + mean, meansqr = torch.split(vec, C) + + var = meansqr - mean * mean + invstd = torch.rsqrt(var + bn.eps) + scale = bn.weight[:feature_dim] * invstd + bias = bn.bias[:feature_dim] - mean * scale + scale = scale.reshape(1, -1, 1, 1) + bias = bias.reshape(1, -1, 1, 1) + return x * scale + bias + + + #if bn.num_features == feature_dim or DynamicBatchNorm2d.SET_RUNNING_STATISTICS: + # return bn(x) + #else: + # exponential_average_factor = 0.0 + + # if bn.training and bn.track_running_stats: + # # TODO: if statement only here to tell the jit to skip emitting this when it is None + # if bn.num_batches_tracked is not None: + # bn.num_batches_tracked += 1 + # if bn.momentum is None: # use cumulative moving average + # exponential_average_factor = 1.0 / float(bn.num_batches_tracked) + # else: # use exponential moving average + # exponential_average_factor = bn.momentum + # return F.batch_norm( + # x, bn.running_mean[:feature_dim], bn.running_var[:feature_dim], bn.weight[:feature_dim], + # bn.bias[:feature_dim], bn.training or not bn.track_running_stats, + # exponential_average_factor, bn.eps, + # ) + + diff --git a/xnas/spaces/BigNAS/ops.py b/xnas/spaces/BigNAS/ops.py new file mode 100644 index 0000000..10e55e9 --- /dev/null +++ b/xnas/spaces/BigNAS/ops.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn + +from collections import OrderedDict +from xnas.spaces.OFA.ops import SEModule, build_activation +from xnas.spaces.OFA.utils import ( + get_same_padding, +) + +class MBConvLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + expand_ratio=6, + mid_channels=None, + act_func="relu6", + use_se=False, + channels_per_group=1, + ): + super(MBConvLayer, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + self.kernel_size = kernel_size + self.stride = stride + self.expand_ratio = expand_ratio + self.mid_channels = mid_channels + self.act_func = act_func + self.use_se = use_se + self.channels_per_group = channels_per_group + + if self.mid_channels is None: + feature_dim = round(self.in_channels * self.expand_ratio) + else: + feature_dim = self.mid_channels + + if self.expand_ratio == 1: + self.inverted_bottleneck = None + else: + self.inverted_bottleneck = nn.Sequential(OrderedDict([ + ("conv", nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)), + ("bn", nn.BatchNorm2d(feature_dim)), + ("act", build_activation(self.act_func, inplace=True)), + ])) + + assert feature_dim % self.channels_per_group == 0 + active_groups = feature_dim // self.channels_per_group + pad = get_same_padding(self.kernel_size) + + # assert feature_dim % self.groups == 0 + # active_groups = feature_dim // self.groups + depth_conv_modules = [ + ( + "conv", + nn.Conv2d( + feature_dim, + feature_dim, + kernel_size, + stride, + pad, + groups=active_groups, + bias=False, + ), + ), + ("bn", nn.BatchNorm2d(feature_dim)), + ("act", build_activation(self.act_func, inplace=True)), + ] + if self.use_se: + depth_conv_modules.append(("se", SEModule(feature_dim))) + self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules)) + + self.point_linear = nn.Sequential( + OrderedDict( + [ + ("conv", nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)), + ("bn", nn.BatchNorm2d(out_channels)), + ] + ) + ) + + def forward(self, x): + if self.inverted_bottleneck: + x = self.inverted_bottleneck(x) + x = self.depth_conv(x) + x = self.point_linear(x) + return x + + @property + def module_str(self): + if self.mid_channels is None: + expand_ratio = self.expand_ratio + else: + expand_ratio = self.mid_channels // self.in_channels + layer_str = "%dx%d_MBConv%d_%s" % ( + self.kernel_size, + self.kernel_size, + expand_ratio, + self.act_func.upper(), + ) + if self.use_se: + layer_str = "SE_" + layer_str + layer_str += "_O%d" % self.out_channels + if self.groups is not None: + layer_str += "_G%d" % self.groups + if isinstance(self.point_linear.bn, nn.GroupNorm): + layer_str += "_GN%d" % self.point_linear.bn.num_groups + elif isinstance(self.point_linear.bn, nn.BatchNorm2d): + layer_str += "_BN" + + return layer_str + + @property + def config(self): + return { + "name": MBConvLayer.__name__, + "in_channels": self.in_channels, + "out_channels": self.out_channels, + "kernel_size": self.kernel_size, + "stride": self.stride, + "expand_ratio": self.expand_ratio, + "mid_channels": self.mid_channels, + "act_func": self.act_func, + "use_se": self.use_se, + "groups": self.groups, + } + + @staticmethod + def build_from_config(config): + return MBConvLayer(**config) diff --git a/xnas/spaces/BigNAS/utils.py b/xnas/spaces/BigNAS/utils.py new file mode 100644 index 0000000..c22109c --- /dev/null +++ b/xnas/spaces/BigNAS/utils.py @@ -0,0 +1,134 @@ +# Implementation adapted from attentiveNAS - https://github.com/facebookresearch/AttentiveNAS + +import torch +import torch.nn as nn +import copy +import math + +multiply_adds = 1 + + +def count_convNd(m, _, y): + cin = m.in_channels + + kernel_ops = m.weight.size()[2] * m.weight.size()[3] + ops_per_element = kernel_ops + output_elements = y.nelement() + + # cout x oW x oH + total_ops = cin * output_elements * ops_per_element // m.groups + m.total_ops = torch.Tensor([int(total_ops)]) + + +def count_linear(m, _, __): + total_ops = m.in_features * m.out_features + + m.total_ops = torch.Tensor([int(total_ops)]) + + +register_hooks = { + nn.Conv1d: count_convNd, + nn.Conv2d: count_convNd, + nn.Conv3d: count_convNd, + ###################################### + nn.Linear: count_linear, + ###################################### + nn.Dropout: None, + nn.Dropout2d: None, + nn.Dropout3d: None, + nn.BatchNorm2d: None, +} + + +def profile(model, input_size=(1, 3, 224, 224), custom_ops=None): + handler_collection = [] + custom_ops = {} if custom_ops is None else custom_ops + + def add_hooks(m_): + if len(list(m_.children())) > 0: + return + + m_.register_buffer('total_ops', torch.zeros(1)) + m_.register_buffer('total_params', torch.zeros(1)) + + for p in m_.parameters(): + m_.total_params += torch.Tensor([p.numel()]) + + m_type = type(m_) + fn = None + + if m_type in custom_ops: + fn = custom_ops[m_type] + elif m_type in register_hooks: + fn = register_hooks[m_type] + else: + # print("Not implemented for ", m_) + pass + + if fn is not None: + # print("Register FLOP counter for module %s" % str(m_)) + _handler = m_.register_forward_hook(fn) + handler_collection.append(_handler) + + original_device = model.parameters().__next__().device + training = model.training + + model.eval() + model.apply(add_hooks) + + x = torch.zeros(input_size).to(original_device) + with torch.no_grad(): + model(x) + + total_ops = 0 + total_params = 0 + for m in model.modules(): + if len(list(m.children())) > 0: # skip for non-leaf module + continue + total_ops += m.total_ops + total_params += m.total_params + + total_ops = total_ops.item() + total_params = total_params.item() + + model.train(training) + model.to(original_device) + + for handler in handler_collection: + handler.remove() + + return total_ops, total_params + + +def count_net_flops_and_params(net, data_shape=(1, 3, 224, 224)): + if isinstance(net, nn.DataParallel): + net = net.module + + net = copy.deepcopy(net) + flop, nparams = profile(net, data_shape) + return flop /1e6, nparams /1e6 + + +def init_model(self, model_init="he_fout"): + """ Conv2d, BatchNorm2d, BatchNorm1d, Linear, """ + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if model_init == 'he_fout': + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif model_init == 'he_fin': + n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + else: + raise NotImplementedError + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + if m.affine: + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + stdv = 1. / math.sqrt(m.weight.size(1)) + m.weight.data.uniform_(-stdv, stdv) + if m.bias is not None: + m.bias.data.zero_() \ No newline at end of file diff --git a/xnas/spaces/OFA/dynamic_ops.py b/xnas/spaces/OFA/dynamic_ops.py index 582cb6a..74d4ba5 100644 --- a/xnas/spaces/OFA/dynamic_ops.py +++ b/xnas/spaces/OFA/dynamic_ops.py @@ -488,36 +488,18 @@ class DynamicMBConvLayer(nn.Module): self.use_se = use_se # build modules - max_middle_channel = make_divisible( - round(max(self.in_channel_list) * max(self.expand_ratio_list))) + max_middle_channel = make_divisible(round(max(self.in_channel_list) * max(self.expand_ratio_list))) if max(self.expand_ratio_list) == 1: self.inverted_bottleneck = None else: - self.inverted_bottleneck = nn.Sequential( - OrderedDict( - [ - ( - "conv", - DynamicConv2d( - max(self.in_channel_list), max_middle_channel - ), - ), - ("bn", DynamicBatchNorm2d(max_middle_channel)), - ("act", build_activation(self.act_func)), - ] - ) - ) + self.inverted_bottleneck = nn.Sequential(OrderedDict([ + ("conv", DynamicConv2d(max(self.in_channel_list), max_middle_channel)), + ("bn", DynamicBatchNorm2d(max_middle_channel)), + ("act", build_activation(self.act_func)), + ])) - self.depth_conv = nn.Sequential( - OrderedDict( - [ - ( - "conv", - DynamicSeparableConv2d( - max_middle_channel, self.kernel_size_list, self.stride, - kernel_trans=kernel_trans - ), - ), + self.depth_conv = nn.Sequential(OrderedDict([ + ("conv", DynamicSeparableConv2d(max_middle_channel, self.kernel_size_list, stride=self.stride, kernel_trans=kernel_trans)), ("bn", DynamicBatchNorm2d(max_middle_channel)), ("act", build_activation(self.act_func)), ] diff --git a/xnas/spaces/OFA/ops.py b/xnas/spaces/OFA/ops.py index ca6fbc8..cc880d8 100644 --- a/xnas/spaces/OFA/ops.py +++ b/xnas/spaces/OFA/ops.py @@ -7,6 +7,7 @@ from xnas.spaces.OFA.utils import ( min_divisible_value, get_same_padding, make_divisible, + drop_connect, ) @@ -46,12 +47,32 @@ def build_activation(act_func, inplace=True): return Hswish(inplace=inplace) elif act_func == "h_sigmoid": return Hsigmoid(inplace=inplace) + elif act_func == 'swish': + return MemoryEfficientSwish() elif act_func is None or act_func == "none": return None else: raise ValueError("do not support: %s" % act_func) +class SwishImplementation(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_tensors[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + +class MemoryEfficientSwish(nn.Module): + def forward(self, x): + return SwishImplementation.apply(x) + + class Hswish(nn.Module): def __init__(self, inplace=True): super(Hswish, self).__init__() @@ -637,27 +658,20 @@ class MBConvLayer(nn.Module): if self.expand_ratio == 1: self.inverted_bottleneck = None else: - self.inverted_bottleneck = nn.Sequential( - OrderedDict( - [ - ( - "conv", - nn.Conv2d( - self.in_channels, feature_dim, 1, 1, 0, bias=False - ), - ), - ("bn", nn.BatchNorm2d(feature_dim)), - ("act", build_activation(self.act_func, inplace=True)), - ] - ) - ) + self.inverted_bottleneck = nn.Sequential(OrderedDict([ + ("conv", nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)), + ("bn", nn.BatchNorm2d(feature_dim)), + ("act", build_activation(self.act_func, inplace=True)), + ])) pad = get_same_padding(self.kernel_size) - groups = ( + active_groups = ( feature_dim if self.groups is None else min_divisible_value(feature_dim, self.groups) ) + # assert feature_dim % self.groups == 0 + # active_groups = feature_dim // self.groups depth_conv_modules = [ ( "conv", @@ -667,7 +681,7 @@ class MBConvLayer(nn.Module): kernel_size, stride, pad, - groups=groups, + groups=active_groups, bias=False, ), ), @@ -739,19 +753,26 @@ class MBConvLayer(nn.Module): class ResidualBlock(nn.Module): - def __init__(self, conv, shortcut): + def __init__(self, conv, shortcut, drop_connect_rate=0): super(ResidualBlock, self).__init__() self.conv = conv + self.mobile_inverted_conv = self.conv # BigNAS self.shortcut = shortcut + self.drop_connect_rate = drop_connect_rate def forward(self, x): + in_channel = x.size(1) if self.conv is None or isinstance(self.conv, ZeroLayer): res = x elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer): res = self.conv(x) else: - res = self.conv(x) + self.shortcut(x) + im = self.shortcut(x) + x = self.conv(x) + if self.drop_connect_rate > 0 and in_channel == im.size(1) and self.shortcut.reduction == 1: + x = drop_connect(x, p=self.drop_connect_rate, training=self.training) + res = x + im return res @property @@ -955,3 +976,52 @@ class ResNetBottleneckBlock(nn.Module): @staticmethod def build_from_config(config): return ResNetBottleneckBlock(**config) + + +class ShortcutLayer(nn.Module): + """ + NOTE: + This class implements similar functionality to `IdentityLayer`, + but adds and removes part of the implementation. + """ + + def __init__(self, in_channels, out_channels, reduction=1): + super(ShortcutLayer, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.reduction = reduction + + self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False) + + def forward(self, x): + if self.reduction > 1: + padding = 0 if x.size(-1) % 2 == 0 else 1 + x = F.avg_pool2d(x, self.reduction, padding=padding) + if self.in_channels != self.out_channels: + x = self.conv(x) + return x + + @property + def module_str(self): + if self.in_channels == self.out_channels and self.reduction == 1: + conv_str = 'IdentityShortcut' + else: + if self.reduction == 1: + conv_str = '%d-%d_Shortcut' % (self.in_channels, self.out_channels) + else: + conv_str = '%d-%d_R%d_Shortcut' % (self.in_channels, self.out_channels, self.reduction) + return conv_str + + @property + def config(self): + return { + 'name': ShortcutLayer.__name__, + 'in_channels': self.in_channels, + 'out_channels': self.out_channels, + 'reduction': self.reduction, + } + + @staticmethod + def build_from_config(config): + return ShortcutLayer(**config) diff --git a/xnas/spaces/OFA/utils.py b/xnas/spaces/OFA/utils.py index be81354..8160273 100644 --- a/xnas/spaces/OFA/utils.py +++ b/xnas/spaces/OFA/utils.py @@ -47,6 +47,7 @@ def get_same_padding(kernel_size): assert kernel_size % 2 > 0, "kernel size should be odd number" return kernel_size // 2 + def make_divisible(v, divisor=8, min_val=None): """ This function is taken from the original tf repo. @@ -67,6 +68,30 @@ def make_divisible(v, divisor=8, min_val=None): return new_v +def drop_connect(inputs, p, training): + """Drop connect. + Args: + input (tensor: BCWH): Input of this structure. + p (float: 0.0~1.0): Probability of drop connection. + training (bool): The running mode. + Returns: + output: Output after drop connection. + """ + assert 0 <= p <= 1, 'p must be in range of [0,1]' + if not training: + return inputs + batch_size = inputs.shape[0] + keep_prob = 1.0 - p + + # generate binary_tensor mask according to probability (p for 0, 1-p for 1) + random_tensor = keep_prob + random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) + binary_tensor = torch.floor(random_tensor) + + output = inputs / keep_prob * binary_tensor + return output + + """ BN related """ def clean_num_batch_tracked(net): -- 2.34.1 From 7986bc44f32b33138437dacdeeb15e0a62040647 Mon Sep 17 00:00:00 2001 From: xfey Date: Wed, 22 Jun 2022 21:00:41 +0800 Subject: [PATCH 6/7] add AttentiveNAS and fix bugs --- configs/search/AttentiveNAS/eval.yaml | 132 ++++ configs/search/AttentiveNAS/train.yaml | 100 +++ scripts/search/AttentiveNAS/train_supernet.py | 271 ++++++++ scripts/search/BigNAS/search.py | 190 +++++ scripts/search/BigNAS/train_supernet.py | 262 +++++++ xnas/algorithms/AttentiveNAS/sampler.py | 117 ++++ xnas/algorithms/RMINAS/utils/random_data.py | 6 +- xnas/core/builder.py | 10 +- xnas/spaces/AttentiveNAS/cnn.py | 652 ++++++++++++++++++ 9 files changed, 1735 insertions(+), 5 deletions(-) create mode 100644 configs/search/AttentiveNAS/eval.yaml create mode 100644 configs/search/AttentiveNAS/train.yaml create mode 100644 scripts/search/AttentiveNAS/train_supernet.py create mode 100644 scripts/search/BigNAS/search.py create mode 100644 scripts/search/BigNAS/train_supernet.py create mode 100644 xnas/algorithms/AttentiveNAS/sampler.py create mode 100644 xnas/spaces/AttentiveNAS/cnn.py diff --git a/configs/search/AttentiveNAS/eval.yaml b/configs/search/AttentiveNAS/eval.yaml new file mode 100644 index 0000000..e3286ec --- /dev/null +++ b/configs/search/AttentiveNAS/eval.yaml @@ -0,0 +1,132 @@ +NUM_GPUS: 4 +RNG_SEED: 2 +SPACE: + NAME: 'attentivenas' +LOADER: + DATASET: 'imagenet' + NUM_CLASSES: 1000 + BATCH_SIZE: 256 + NUM_WORKERS: 4 + USE_VAL: True + TRANSFORM: "auto_augment_tf" +SEARCH: + IM_SIZE: 224 +ATTENTIVENAS: + BN_MOMENTUM: 0.1 + BN_EPS: 1.e-5 + POST_BN_CALIBRATION_BATCH_NUM: 64 + ACTIVE_SUBNET: # chosen from following settings + # attentive_nas_a0 + RESOLUTION: 192 + WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792] + KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3] + EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6] + DEPTH: [1, 3, 3, 3, 3, 3, 1] + + # # attentive_nas_a1 + # RESOLUTION: 224 + # WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1984] + # KERNEL_SIZE: [3, 3, 3, 5, 3, 5, 3] + # EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6] + # DEPTH: [1, 3, 3, 3, 3, 3, 1] + + # # attentive_nas_a2 + # RESOLUTION: 224 + # WIDTH: [16, 16, 24, 32, 64, 112, 200, 224, 1984] + # KERNEL_SIZE: [3, 3, 3, 3, 3, 5, 3] + # EXPAND_RATIO: [1, 4, 5, 4, 4, 6, 6] + # DEPTH: [1, 3, 3, 3, 3, 4, 1] + + # # attentive_nas_a3 + # RESOLUTION: 224 + # WIDTH: [16, 16, 24, 32, 64, 112, 208, 224, 1984] + # KERNEL_SIZE: [3, 3, 3, 5, 3, 3, 3] + # EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6] + # DEPTH: [2, 3, 3, 4, 3, 5, 1] + + # # attentive_nas_a4 + # RESOLUTION: 256 + # WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1984] + # KERNEL_SIZE: [3, 3, 3, 5, 3, 5, 3] + # EXPAND_RATIO: [1, 4, 4, 5, 4, 6, 6] + # DEPTH: [1, 3, 3, 4, 3, 5, 1] + + # # attentive_nas_a5 + # RESOLUTION: 256 + # WIDTH: [16, 16, 24, 32, 72, 112, 192, 216, 1792] + # KERNEL_SIZE: [3, 3, 3, 5, 3, 3, 3] + # EXPAND_RATIO: [1, 4, 5, 4, 4, 6, 6] + # DEPTH: [1, 3, 3, 3, 4, 6, 1] + + # # attentive_nas_a6 + # RESOLUTION: 288 + # WIDTH: [16, 16, 24, 32, 64, 112, 216, 224, 1984] + # KERNEL_SIZE: [3, 3, 3, 3, 3, 5, 3] + # EXPAND_RATIO: [1, 4, 6, 5, 4, 6, 6] + # DEPTH: [1, 3, 3, 4, 4, 6, 1] + SUPERNET_CFG: + use_v3_head: True + resolutions: [192, 224, 256, 288] + first_conv: + c: [16, 24] + act_func: 'swish' + s: 2 + mb1: + c: [16, 24] + d: [1, 2] + k: [3, 5] + t: [1] + s: 1 + act_func: 'swish' + se: False + mb2: + c: [24, 32] + d: [3, 4, 5] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: False + mb3: + c: [32, 40] + d: [3, 4, 5, 6] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: True + mb4: + c: [64, 72] + d: [3, 4, 5, 6] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: False + mb5: + c: [112, 120, 128] + d: [3, 4, 5, 6, 7, 8] + k: [3, 5] + t: [4, 5, 6] + s: 1 + act_func: 'swish' + se: True + mb6: + c: [192, 200, 208, 216] + d: [3, 4, 5, 6, 7, 8] + k: [3, 5] + t: [6] + s: 2 + act_func: 'swish' + se: True + mb7: + c: [216, 224] + d: [1, 2] + k: [3, 5] + t: [6] + s: 1 + act_func: 'swish' + se: True + last_conv: + c: [1792, 1984] + act_func: 'swish' diff --git a/configs/search/AttentiveNAS/train.yaml b/configs/search/AttentiveNAS/train.yaml new file mode 100644 index 0000000..16cf5a8 --- /dev/null +++ b/configs/search/AttentiveNAS/train.yaml @@ -0,0 +1,100 @@ +NUM_GPUS: 4 +RNG_SEED: 0 +SPACE: + NAME: 'attentivenas' +LOADER: + DATASET: 'imagenet' + NUM_CLASSES: 1000 + BATCH_SIZE: 64 # 32*8 in total + NUM_WORKERS: 4 + USE_VAL: True + TRANSFORM: "auto_augment_tf" +OPTIM: + GRAD_CLIP: 1. + WARMUP_EPOCH: 5 + MAX_EPOCH: 360 + WEIGHT_DECAY: 1.e-5 + BASE_LR: 0.2 + NESTEROV: True +SEARCH: + LOSS_FUN: "cross_entropy_smooth" + LABEL_SMOOTH: 0.1 +TRAIN: + DROP_PATH_PROB: 0.2 +ATTENTIVENAS: + SANDWICH_NUM: 4 # max + 2*middle + min + DROP_CONNECT: 0.2 + BN_MOMENTUM: 0. + BN_EPS: 1.e-5 + POST_BN_CALIBRATION_BATCH_NUM: 64 + SAMPLER: + METHOD: 'bestup' + MAP_PATH: 'xnas/algorithms/AttentiveNAS/flops_archs_off_table.map' + DISCRETIZE_STEP: 25 + NUM_TRIALS: 3 + SUPERNET_CFG: + use_v3_head: True + resolutions: [192, 224, 256, 288] + first_conv: + c: [16, 24] + act_func: 'swish' + s: 2 + mb1: + c: [16, 24] + d: [1, 2] + k: [3, 5] + t: [1] + s: 1 + act_func: 'swish' + se: False + mb2: + c: [24, 32] + d: [3, 4, 5] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: False + mb3: + c: [32, 40] + d: [3, 4, 5, 6] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: True + mb4: + c: [64, 72] + d: [3, 4, 5, 6] + k: [3, 5] + t: [4, 5, 6] + s: 2 + act_func: 'swish' + se: False + mb5: + c: [112, 120, 128] + d: [3, 4, 5, 6, 7, 8] + k: [3, 5] + t: [4, 5, 6] + s: 1 + act_func: 'swish' + se: True + mb6: + c: [192, 200, 208, 216] + d: [3, 4, 5, 6, 7, 8] + k: [3, 5] + t: [6] + s: 2 + act_func: 'swish' + se: True + mb7: + c: [216, 224] + d: [1, 2] + k: [3, 5] + t: [6] + s: 1 + act_func: 'swish' + se: True + last_conv: + c: [1792, 1984] + act_func: 'swish' \ No newline at end of file diff --git a/scripts/search/AttentiveNAS/train_supernet.py b/scripts/search/AttentiveNAS/train_supernet.py new file mode 100644 index 0000000..b6e7dfc --- /dev/null +++ b/scripts/search/AttentiveNAS/train_supernet.py @@ -0,0 +1,271 @@ +"""AttentiveNAS supernet training""" + +import os +import random +import operator + +import torch +import torch.nn as nn + +import xnas.core.config as config +from xnas.datasets.loader import get_normal_dataloader +import xnas.logger.meter as meter +import xnas.logger.logging as logging +from xnas.core.config import cfg +from xnas.core.builder import * + +# DDP +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +# AttentiveNAS +from xnas.runner.trainer import Trainer +from xnas.runner.scheduler import adjust_learning_rate_per_batch +from xnas.algorithms.AttentiveNAS.sampler import ArchSampler +from xnas.spaces.OFA.utils import list_mean +from xnas.spaces.BigNAS.utils import init_model + + +# Load config and check +config.load_configs() +logger = logging.get_logger(__name__) + + +def main(local_rank, world_size): + setup_env() + torch.cuda.set_device(local_rank) + dist.init_process_group(backend='nccl', rank=local_rank, world_size=world_size) + # Network + net = space_builder().to(local_rank) + init_model(net) + # Loss function + criterion = criterion_builder() + soft_criterion = criterion_builder('kl_soft') + + # Data loaders + [train_loader, valid_loader] = get_normal_dataloader() + + # Optimizers + net_params = [ + # parameters with weight decay + {"params": net.get_parameters(['bn', 'bias'], mode="exclude"), "weight_decay": cfg.OPTIM.WEIGHT_DECAY}, + # parameters without weight decay + {"params": net.get_parameters(['bn', 'bias'], mode="include") , "weight_decay": 0}, + ] + optimizer = optimizer_builder("SGD", net_params) + + # sampler for AttentiveNAS + sampler = ArchSampler( + cfg.ATTENTIVENAS.SAMPLER.MAP_PATH, cfg.ATTENTIVENAS.SAMPLER.DISCRETIZE_STEP, net, None + ) + + net = DDP(net, device_ids=[local_rank], find_unused_parameters=True) + + # Initialize Recorder + attnas_trainer = AttentivenasTrainer( + model=net, + criterion=criterion, + soft_criterion=soft_criterion, + sampler=sampler, + optimizer=optimizer, + lr_scheduler=None, + train_loader=train_loader, + test_loader=valid_loader, + ) + + # Resume + start_epoch = attnas_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0 + + # Training + logger.info("Start AttentiveNAS training.") + dist.barrier() + attnas_trainer.start() + for cur_epoch in range(start_epoch, cfg.OPTIM.WARMUP_EPOCH+cfg.OPTIM.MAX_EPOCH): + attnas_trainer.train_epoch(cur_epoch, rank=local_rank) + if local_rank == 0: + if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH: + attnas_trainer.validate(cur_epoch, local_rank) + attnas_trainer.finish() + dist.barrier() + torch.cuda.empty_cache() + + +class AttentivenasTrainer(Trainer): + """Trainer for AttentiveNAS.""" + def __init__(self, model, criterion, soft_criterion, sampler, optimizer, lr_scheduler, train_loader, test_loader): + super().__init__(model, criterion, optimizer, lr_scheduler, train_loader, test_loader) + self.sandwich_sample_num = max(2, cfg.ATTENTIVENAS.SANDWICH_NUM) # containing max & min + self.soft_criterion = soft_criterion + self.sampler = sampler + + def train_epoch(self, cur_epoch, rank=0): + self.model.train() + # lr = self.lr_scheduler.get_last_lr()[0] + cur_step = cur_epoch * len(self.train_loader) + # self.writer.add_scalar('train/lr', lr, cur_step) + self.train_meter.iter_tic() + self.train_loader.sampler.set_epoch(cur_epoch) # DDP + for cur_iter, (inputs, labels) in enumerate(self.train_loader): + inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True) + + # Adjust lr per iter + cur_lr = adjust_learning_rate_per_batch( + epoch=cur_epoch, + n_iter=len(self.train_loader), + iter=cur_iter, + warmup=(cur_epoch < cfg.OPTIM.WARMUP_EPOCH), + ) + for param_group in self.optimizer.param_groups: + param_group["lr"] = cur_lr + # self.writer.add_scalar('train/lr', cur_lr, cur_step) + + self.optimizer.zero_grad() + + ## Sandwich Rule ## + # Step 1. Largest network sampling & regularization + self.model.module.sample_max_subnet() + self.model.module.set_dropout_rate(cfg.TRAIN.DROP_PATH_PROB, cfg.ATTENTIVENAS.DROP_CONNECT) + preds = self.model(inputs) + loss = self.criterion(preds, labels) + loss.backward() + + with torch.no_grad(): + soft_logits = preds.clone().detach() + + # Step 2. sample smaller networks + self.model.module.set_dropout_rate(0, 0) + for arch_id in range(1, self.sandwich_sample_num): + if arch_id == self.sandwich_sample_num - 1: + self.model.module.sample_min_subnet() + else: + if self.sampler is not None: + sampling_method = cfg.ATTENTIVENAS.SAMPLER.METHOD + if sampling_method in ['bestup', 'worstup']: + target_flops = self.sampler.sample_one_target_flops() + candidate_archs = self.sampler.sample_archs_according_to_flops( + target_flops, n_samples=cfg.ATTENTIVENAS.SAMPLER.NUM_TRIALS + ) + my_pred_accs = [] + for arch in candidate_archs: + self.model.module.set_active_subnet(**arch) + with torch.no_grad(): + my_pred_accs.append(-1.0 * self.criterion(self.model(inputs), labels)) + if sampling_method == 'bestup': + idx, _ = max(enumerate(my_pred_accs), key=operator.itemgetter(1)) + else: + idx, _ = min(enumerate(my_pred_accs), key=operator.itemgetter(1)) + self.model.module.set_active_subnet(**candidate_archs[idx]) #reset + else: + subnet_seed = int("%d%.3d%.3d" % (cur_step, arch_id, 0)) + random.seed(subnet_seed) + self.model.module.sample_active_subnet() + + preds = self.model(inputs) + + if self.soft_criterion is not None: + loss = self.soft_criterion(preds, soft_logits) + else: + loss = self.criterion(preds, labels) + loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), cfg.OPTIM.GRAD_CLIP) + self.optimizer.step() + + # calculating errors. The source code of AttentiveNAS uses statistics of the smallest network and XNAS follows. + top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) + loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item() + self.train_meter.iter_toc() + self.train_meter.update_stats(top1_err, top5_err, loss, cur_lr, inputs.size(0) * cfg.NUM_GPUS) + self.train_meter.log_iter_stats(cur_epoch, cur_iter) + self.train_meter.iter_tic() + # self.writer.add_scalar('train/loss', i_loss, cur_step) + # self.writer.add_scalar('train/top1_error', i_top1err, cur_step) + # self.writer.add_scalar('train/top5_error', i_top5err, cur_step) + cur_step += 1 + # Log epoch stats + self.train_meter.log_epoch_stats(cur_epoch) + self.train_meter.reset() + # self.lr_scheduler.step() + # Saving checkpoint + if rank==0 and (cur_epoch + 1) % cfg.SAVE_PERIOD == 0: + self.saving(cur_epoch) + + @torch.no_grad() + def test_epoch(self, subnet, cur_epoch, rank=0): + subnet.eval() + self.test_meter.reset(True) + self.test_meter.iter_tic() + for cur_iter, (inputs, labels) in enumerate(self.test_loader): + inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True) + preds = subnet(inputs) + top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) + top1_err, top5_err = top1_err.item(), top5_err.item() + + self.test_meter.iter_toc() + self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) + self.test_meter.log_iter_stats(cur_epoch, cur_iter) + self.test_meter.iter_tic() + top1_err = self.test_meter.mb_top1_err.get_win_avg() + top5_err = self.test_meter.mb_top5_err.get_win_avg() + # self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch) + # self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch) + # Log epoch stats + self.test_meter.log_epoch_stats(cur_epoch) + # self.test_meter.reset() + return top1_err, top5_err + + + def validate(self, cur_epoch, rank, bn_calibration=True): + subnets_to_be_evaluated = { + 'attentive_nas_min_net': {}, + 'attentive_nas_max_net': {}, + } + + top1_list, top5_list = [], [] + with torch.no_grad(): + for net_id in subnets_to_be_evaluated: + if net_id == 'attentive_nas_min_net': + self.model.module.sample_min_subnet() + elif net_id == 'attentive_nas_max_net': + self.model.module.sample_max_subnet() + elif net_id.startswith('attentive_nas_random_net'): + self.model.module.sample_active_subnet() + else: + self.model.module.set_active_subnet( + subnets_to_be_evaluated[net_id]['resolution'], + subnets_to_be_evaluated[net_id]['width'], + subnets_to_be_evaluated[net_id]['depth'], + subnets_to_be_evaluated[net_id]['kernel_size'], + subnets_to_be_evaluated[net_id]['expand_ratio'], + ) + + subnet = self.model.module.get_active_subnet() + subnet.to(rank) + logger.info("evaluating subnet {}".format(net_id)) + + if bn_calibration: + subnet.eval() + logger.info("Calibrating BN running statistics.") + subnet.reset_running_stats_for_calibration() + for cur_iter, (inputs, _) in enumerate(self.train_loader): + if cur_iter >= cfg.ATTENTIVENAS.POST_BN_CALIBRATION_BATCH_NUM: + break + inputs = inputs.to(rank) + subnet(inputs) # forward only + + top1_err, top5_err = self.test_epoch(subnet, cur_epoch, rank) + top1_list.append(top1_err), top5_list.append(top5_err) + logger.info("Average@all_subnets top1_err:{} top5_err:{}".format(list_mean(top1_list), list_mean(top5_list))) + if self.best_err > list_mean(top1_list): + self.best_err = list_mean(top1_list) + self.saving(cur_epoch, best=True) + + +if __name__ == '__main__': + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '23333' + + if torch.cuda.is_available(): + cfg.NUM_GPUS = torch.cuda.device_count() + + mp.spawn(main, nprocs=cfg.NUM_GPUS, args=(cfg.NUM_GPUS,), join=True) diff --git a/scripts/search/BigNAS/search.py b/scripts/search/BigNAS/search.py new file mode 100644 index 0000000..4a52a01 --- /dev/null +++ b/scripts/search/BigNAS/search.py @@ -0,0 +1,190 @@ +"""BigNAS subnet searching: Coarse-to-fine Architecture Selection""" + +import numpy as np +from itertools import product + +import torch + +import xnas.core.config as config +import xnas.logger.meter as meter +import xnas.logger.logging as logging +from xnas.core.builder import * +from xnas.core.config import cfg +from xnas.datasets.loader import get_normal_dataloader +from xnas.logger.meter import TestMeter + + +# Load config and check +config.load_configs() +logger = logging.get_logger(__name__) + + +def get_all_subnets(): + # get all subnets + all_subnets = [] + subnet_sets = cfg.BIGNAS.SEARCH_CFG_SETS + stage_names = ['mb1', 'mb2', 'mb3', 'mb4', 'mb5', 'mb6', 'mb7'] + + mb_stage_subnets = [] + for mbstage in stage_names: + mb_block_cfg = getattr(subnet_sets, mbstage) + mb_stage_subnets.append(list(product( + mb_block_cfg.c, + mb_block_cfg.d, + mb_block_cfg.k, + mb_block_cfg.t + ))) + + all_mb_stage_subnets = list(product(*mb_stage_subnets)) + + resolutions = getattr(subnet_sets, 'resolutions') + first_conv = getattr(subnet_sets, 'first_conv') + last_conv = getattr(subnet_sets, 'last_conv') + + for res in resolutions: + for fc in first_conv.c: + for mb in all_mb_stage_subnets: + np_mb_choice = np.array(mb) + width = np_mb_choice[:, 0].tolist() # c + depth = np_mb_choice[:, 1].tolist() # d + kernel = np_mb_choice[:, 2].tolist() # k + expand = np_mb_choice[:, 3].tolist() # t + for lc in last_conv.c: + all_subnets.append({ + 'resolution': res, + 'width': [fc] + width + [lc], + 'depth': depth, + 'kernel_size': kernel, + 'expand_ratio': expand + }) + return all_subnets + + +def main(): + setup_env() + supernet = space_builder().cuda() + supernet.load_weights_from_pretrained_models(cfg.SEARCH.WEIGHTS) + + [train_loader, valid_loader] = get_normal_dataloader() + + test_meter = TestMeter(len(valid_loader)) + + all_subnets = get_all_subnets() + benchmarks = [] + + # Phase 1. coarse search + for k,subnet_cfg in enumerate(all_subnets): + supernet.set_active_subnet( + subnet_cfg['resolution'], + subnet_cfg['width'], + subnet_cfg['depth'], + subnet_cfg['kernel_size'], + subnet_cfg['expand_ratio'], + ) + subnet = supernet.get_active_subnet().cuda() + + # Validate + top1_err, top5_err = validate(subnet, train_loader, valid_loader, test_meter) + flops = supernet.compute_active_subnet_flops() + + logger.info("[{}/{}] flops:{} top1_err:{} top5_err:{}".format( + k+1, len(all_subnets), flops, top1_err, top5_err + )) + + benchmarks.append({ + 'subnet_cfg': subnet_cfg, + 'flops': flops, + 'top1_err': top1_err, + 'top5_err': top5_err + }) + + # Phase 2. fine-grained search + try: + best_subnet_info = list(filter( + lambda k: k['flops'] < cfg.BIGNAS.CONSTRAINT_FLOPS, + sorted(benchmarks, key=lambda d: d['top1_err'])))[0] + best_subnet_cfg = best_subnet_info['subnet_cfg'] + best_subnet_top1 = best_subnet_info['top1_err'] + except IndexError: + logger.info("Cannot find subnets under {} FLOPs".format(cfg.BIGNAS.CONSTRAINT_FLOPS)) + exit(1) + + for mutate_epoch in range(cfg.BIGNAS.NUM_MUTATE): + new_subnet_cfg = supernet.mutate_and_reset(best_subnet_cfg) + prev_cfgs = [i['subnet_cfg'] for i in benchmarks] + if new_subnet_cfg in prev_cfgs: + continue + + subnet = supernet.get_active_subnet().cuda() + # Validate + top1_err, top5_err = validate(subnet, train_loader, valid_loader, test_meter) + flops = supernet.compute_active_subnet_flops() + + logger.info("[{}/{}] flops:{} top1_err:{} top5_err:{}".format( + mutate_epoch+1, cfg.BIGNAS.NUM_MUTATE, flops, top1_err, top5_err + )) + + benchmarks.append({ + 'subnet_cfg': subnet_cfg, + 'flops': flops, + 'top1_err': top1_err, + 'top5_err': top5_err + }) + + if flops < cfg.BIGNAS.CONSTRAINT_FLOPS and top1_err < best_subnet_top1: + best_subnet_cfg = new_subnet_cfg + best_subnet_top1 = top1_err + + # Final best architecture + logger.info("="*20 + "\nMutate Finished.") + logger.info("Best Architecture:\n{}\n Best top1_err:{}".format( + best_subnet_cfg, best_subnet_top1 + )) + + +@torch.no_grad() +def validate(subnet, train_loader, valid_loader, test_meter): + # BN calibration + subnet.eval() + logger.info("Calibrating BN running statistics.") + subnet.reset_running_stats_for_calibration() + for cur_iter, (inputs, _) in enumerate(train_loader): + if cur_iter >= cfg.BIGNAS.POST_BN_CALIBRATION_BATCH_NUM: + break + inputs = inputs.cuda() + subnet(inputs) # forward only + + top1_err, top5_err = test_epoch(subnet, valid_loader, test_meter) + return top1_err, top5_err + + +def test_epoch(subnet, test_loader, test_meter): + subnet.eval() + test_meter.reset(True) + test_meter.iter_tic() + for cur_iter, (inputs, labels) in enumerate(test_loader): + # [debug] + if cur_iter > 20: + break + + inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) + preds = subnet(inputs) + top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) + top1_err, top5_err = top1_err.item(), top5_err.item() + + test_meter.iter_toc() + test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) + test_meter.log_iter_stats(0, cur_iter) + test_meter.iter_tic() + top1_err = test_meter.mb_top1_err.get_win_avg() + top5_err = test_meter.mb_top5_err.get_win_avg() + # self.writer.add_scalar('val/top1_error', test_meter.mb_top1_err.get_win_avg(), cur_epoch) + # self.writer.add_scalar('val/top5_error', test_meter.mb_top5_err.get_win_avg(), cur_epoch) + # Log epoch stats + test_meter.log_epoch_stats(0) + # test_meter.reset() + return top1_err, top5_err + + +if __name__ == "__main__": + main() diff --git a/scripts/search/BigNAS/train_supernet.py b/scripts/search/BigNAS/train_supernet.py new file mode 100644 index 0000000..bc968bd --- /dev/null +++ b/scripts/search/BigNAS/train_supernet.py @@ -0,0 +1,262 @@ +"""AttentiveNAS supernet training""" + +import os +import random + +import torch +import torch.nn as nn + +import xnas.core.config as config +from xnas.datasets.loader import get_normal_dataloader +import xnas.logger.meter as meter +import xnas.logger.logging as logging +from xnas.core.config import cfg +from xnas.core.builder import * + +# DDP +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +# AttentiveNAS +from xnas.runner.trainer import Trainer +from xnas.runner.scheduler import adjust_learning_rate_per_batch +from xnas.spaces.OFA.utils import list_mean +from xnas.spaces.BigNAS.utils import init_model + +# Load config and check +config.load_configs() +logger = logging.get_logger(__name__) + + +def main(local_rank, world_size): + setup_env() + torch.cuda.set_device(local_rank) + dist.init_process_group(backend='nccl', rank=local_rank, world_size=world_size) + # Network + net = space_builder().to(local_rank) + init_model(net) + # Loss function + criterion = criterion_builder() + soft_criterion = criterion_builder('kl_soft') + + # Data loaders + [train_loader, valid_loader] = get_normal_dataloader() + + # Optimizers + net_params = [ + # parameters with weight decay + {"params": net.get_parameters(['bn', 'bias'], mode="exclude"), "weight_decay": cfg.OPTIM.WEIGHT_DECAY}, + # parameters without weight decay + {"params": net.get_parameters(['bn', 'bias'], mode="include") , "weight_decay": 0}, + ] + optimizer = optimizer_builder("SGD", net_params) + # Rule: only regularize the biggest model + optimizer_no_wd = torch.optim.SGD( + net.parameters(), + cfg.OPTIM.BASE_LR, + cfg.OPTIM.MOMENTUM, + cfg.OPTIM.DAMPENING, + 0, # no weight decay. + cfg.OPTIM.NESTEROV, + ) + + net = DDP(net, device_ids=[local_rank], find_unused_parameters=True) + + # Initialize Recorder + bignas_trainer = BigNASTrainer( + model=net, + criterion=criterion, + soft_criterion=soft_criterion, + optimizer=optimizer, + optim_no_wd=optimizer_no_wd, + lr_scheduler=None, + train_loader=train_loader, + test_loader=valid_loader, + ) + + # Resume + start_epoch = bignas_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0 + + # Training + logger.info("Start BigNAS training.") + dist.barrier() + bignas_trainer.start() + for cur_epoch in range(start_epoch, cfg.OPTIM.WARMUP_EPOCH+cfg.OPTIM.MAX_EPOCH): + bignas_trainer.train_epoch(cur_epoch, rank=local_rank) + if local_rank == 0: + if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH: + bignas_trainer.validate(cur_epoch, local_rank) + bignas_trainer.finish() + dist.barrier() + torch.cuda.empty_cache() + + +class BigNASTrainer(Trainer): + """Trainer for BigNAS.""" + def __init__(self, model, criterion, soft_criterion, optimizer, optim_no_wd, lr_scheduler, train_loader, test_loader): + super().__init__(model, criterion, optimizer, lr_scheduler, train_loader, test_loader) + self.sandwich_sample_num = max(2, cfg.BIGNAS.SANDWICH_NUM) # containing max & min + self.soft_criterion = soft_criterion + self.optim_no_wd = optim_no_wd + + def train_epoch(self, cur_epoch, rank=0): + self.model.train() + # lr = self.lr_scheduler.get_last_lr()[0] + cur_step = cur_epoch * len(self.train_loader) + # self.writer.add_scalar('train/lr', lr, cur_step) + self.train_meter.iter_tic() + self.train_loader.sampler.set_epoch(cur_epoch) # DDP + for cur_iter, (inputs, labels) in enumerate(self.train_loader): + # [debug] + if cur_iter > 20: + break + + inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True) + + # Adjust lr per iter + cur_lr = adjust_learning_rate_per_batch( + epoch=cur_epoch, + n_iter=len(self.train_loader), + iter=cur_iter, + warmup=(cur_epoch < cfg.OPTIM.WARMUP_EPOCH), + ) + # Rule: constrant ending + cur_lr = max(cur_lr, 0.05 * cfg.OPTIM.BASE_LR) + for param_group in self.optimizer.param_groups: + param_group["lr"] = cur_lr + # self.writer.add_scalar('train/lr', cur_lr, cur_step) + + ## Sandwich Rule ## + # Step 1. Largest network sampling & regularization + self.optimizer.zero_grad() + self.model.module.sample_max_subnet() + self.model.module.set_dropout_rate(cfg.TRAIN.DROP_PATH_PROB, cfg.BIGNAS.DROP_CONNECT) + preds = self.model(inputs) + loss = self.criterion(preds, labels) + loss.backward() + self.optimizer.step() + with torch.no_grad(): + soft_logits = preds.clone().detach() + + # Step 2. sample smaller networks + self.optim_no_wd.zero_grad() + self.model.module.set_dropout_rate(0, 0) + for arch_id in range(1, self.sandwich_sample_num): + if arch_id == self.sandwich_sample_num - 1: + self.model.module.sample_min_subnet() + else: + subnet_seed = int("%d%.3d%.3d" % (cur_step, arch_id, 0)) + random.seed(subnet_seed) + self.model.module.sample_active_subnet() + preds = self.model(inputs) + if self.soft_criterion is not None: + loss = self.soft_criterion(preds, soft_logits) + else: + loss = self.criterion(preds, labels) + loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), cfg.OPTIM.GRAD_CLIP) + self.optim_no_wd.step() + + # calculating errors. The source code of AttentiveNAS uses statistics of the smallest network and XNAS follows. + top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) + loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item() + self.train_meter.iter_toc() + self.train_meter.update_stats(top1_err, top5_err, loss, cur_lr, inputs.size(0) * cfg.NUM_GPUS) + self.train_meter.log_iter_stats(cur_epoch, cur_iter) + self.train_meter.iter_tic() + # self.writer.add_scalar('train/loss', i_loss, cur_step) + # self.writer.add_scalar('train/top1_error', i_top1err, cur_step) + # self.writer.add_scalar('train/top5_error', i_top5err, cur_step) + cur_step += 1 + # Log epoch stats + self.train_meter.log_epoch_stats(cur_epoch) + self.train_meter.reset() + # self.lr_scheduler.step() + # Saving checkpoint + if rank==0 and (cur_epoch + 1) % cfg.SAVE_PERIOD == 0: + self.saving(cur_epoch) + + @torch.no_grad() + def test_epoch(self, subnet, cur_epoch, rank=0): + subnet.eval() + self.test_meter.reset(True) + self.test_meter.iter_tic() + for cur_iter, (inputs, labels) in enumerate(self.test_loader): + # [debug] + if cur_iter > 20: + break + + inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True) + preds = subnet(inputs) + top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) + top1_err, top5_err = top1_err.item(), top5_err.item() + + self.test_meter.iter_toc() + self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) + self.test_meter.log_iter_stats(cur_epoch, cur_iter) + self.test_meter.iter_tic() + top1_err = self.test_meter.mb_top1_err.get_win_avg() + top5_err = self.test_meter.mb_top5_err.get_win_avg() + # self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch) + # self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch) + # Log epoch stats + self.test_meter.log_epoch_stats(cur_epoch) + # self.test_meter.reset() + return top1_err, top5_err + + + def validate(self, cur_epoch, rank, bn_calibration=True): + subnets_to_be_evaluated = { + 'bignas_min_net': {}, + 'bignas_max_net': {}, + } + + top1_list, top5_list = [], [] + with torch.no_grad(): + for net_id in subnets_to_be_evaluated: + if net_id == 'bignas_min_net': + self.model.module.sample_min_subnet() + elif net_id == 'bignas_max_net': + self.model.module.sample_max_subnet() + elif net_id.startswith('bignas_random_net'): + self.model.module.sample_active_subnet() + else: + self.model.module.set_active_subnet( + subnets_to_be_evaluated[net_id]['resolution'], + subnets_to_be_evaluated[net_id]['width'], + subnets_to_be_evaluated[net_id]['depth'], + subnets_to_be_evaluated[net_id]['kernel_size'], + subnets_to_be_evaluated[net_id]['expand_ratio'], + ) + + subnet = self.model.module.get_active_subnet() + subnet.to(rank) + logger.info("evaluating subnet {}".format(net_id)) + + if bn_calibration: + subnet.eval() + logger.info("Calibrating BN running statistics.") + subnet.reset_running_stats_for_calibration() + for cur_iter, (inputs, _) in enumerate(self.train_loader): + if cur_iter >= cfg.BIGNAS.POST_BN_CALIBRATION_BATCH_NUM: + break + inputs = inputs.to(rank) + subnet(inputs) # forward only + + top1_err, top5_err = self.test_epoch(subnet, cur_epoch, rank) + top1_list.append(top1_err), top5_list.append(top5_err) + logger.info("Average@all_subnets top1_err:{} top5_err:{}".format(list_mean(top1_list), list_mean(top5_list))) + if self.best_err > list_mean(top1_list): + self.best_err = list_mean(top1_list) + self.saving(cur_epoch, best=True) + + +if __name__ == '__main__': + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '23333' + + if torch.cuda.is_available(): + cfg.NUM_GPUS = torch.cuda.device_count() + + mp.spawn(main, nprocs=cfg.NUM_GPUS, args=(cfg.NUM_GPUS,), join=True) diff --git a/xnas/algorithms/AttentiveNAS/sampler.py b/xnas/algorithms/AttentiveNAS/sampler.py new file mode 100644 index 0000000..32baf8f --- /dev/null +++ b/xnas/algorithms/AttentiveNAS/sampler.py @@ -0,0 +1,117 @@ +import random + +def count_helper(v, flops, m): + if flops not in m: + m[flops] = {} + if v not in m[flops]: + m[flops][v] = 0 + m[flops][v] += 1 + + +def round_flops(flops, step): + return int(round(flops / step) * step) + + +def convert_count_to_prob(m): + if isinstance(m[list(m.keys())[0]], dict): + for k in m: + convert_count_to_prob(m[k]) + else: + t = sum(m.values()) + for k in m: + m[k] = 1.0 * m[k] / t + + +def sample_helper(flops, m): + keys = list(m[flops].keys()) + probs = list(m[flops].values()) + return random.choices(keys, weights=probs)[0] + + +def build_trasition_prob_matrix(file_handler, step): + # initlizie + prob_map = {} + prob_map['discretize_step'] = step + for k in ['flops', 'resolution', 'width', 'depth', 'kernel_size', 'expand_ratio']: + prob_map[k] = {} + + cc = 0 + for line in file_handler: + vals = eval(line.strip()) + + # discretize + flops = round_flops(vals['flops'], step) + prob_map['flops'][flops] = prob_map['flops'].get(flops, 0) + 1 + + # resolution + r = vals['resolution'] + count_helper(r, flops, prob_map['resolution']) + + for k in ['width', 'depth', 'kernel_size', 'expand_ratio']: + for idx, v in enumerate(vals[k]): + if idx not in prob_map[k]: + prob_map[k][idx] = {} + count_helper(v, flops, prob_map[k][idx]) + + cc += 1 + + # convert count to probability + for k in ['flops', 'resolution', 'width', 'depth', 'kernel_size', 'expand_ratio']: + convert_count_to_prob(prob_map[k]) + prob_map['n_observations'] = cc + return prob_map + + + +class ArchSampler(): + def __init__(self, arch_to_flops_map_file_path, discretize_step, model, acc_predictor=None): + super(ArchSampler, self).__init__() + + with open(arch_to_flops_map_file_path, 'r') as fp: + self.prob_map = build_trasition_prob_matrix(fp, discretize_step) + + self.discretize_step = discretize_step + self.model = model + + self.acc_predictor = acc_predictor + + self.min_flops = min(list(self.prob_map['flops'].keys())) + self.max_flops = max(list(self.prob_map['flops'].keys())) + + self.curr_sample_pool = None #TODO; architecture samples could be generated in an asynchronous way + + + def sample_one_target_flops(self, flops_uniform=False): + f_vals = list(self.prob_map['flops'].keys()) + f_probs = list(self.prob_map['flops'].values()) + + if flops_uniform: + return random.choice(f_vals) + else: + return random.choices(f_vals, weights=f_probs)[0] + + + def sample_archs_according_to_flops(self, target_flops, n_samples=1, max_trials=100, return_flops=True, return_trials=False): + archs = [] + #for _ in range(n_samples): + while len(archs) < n_samples: + for _trial in range(max_trials+1): + arch = {} + arch['resolution'] = sample_helper(target_flops, self.prob_map['resolution']) + for k in ['width', 'kernel_size', 'depth', 'expand_ratio']: + arch[k] = [] + for idx in sorted(list(self.prob_map[k].keys())): + arch[k].append(sample_helper(target_flops, self.prob_map[k][idx])) + if self.model: + self.model.set_active_subnet(**arch) + flops = self.model.compute_active_subnet_flops() + if return_flops: + arch['flops'] = flops + if round_flops(flops, self.discretize_step) == target_flops: + break + else: + raise NotImplementedError + #accepte the sample anyway + archs.append(arch) + return archs + diff --git a/xnas/algorithms/RMINAS/utils/random_data.py b/xnas/algorithms/RMINAS/utils/random_data.py index 2168754..8fd0f49 100644 --- a/xnas/algorithms/RMINAS/utils/random_data.py +++ b/xnas/algorithms/RMINAS/utils/random_data.py @@ -8,9 +8,9 @@ def get_random_data(batchsize, name): device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') if name == 'imagenet': train_loader, _ = ImageFolder( - "./data/imagenet/ILSVRC2012_img_train/", - [0.5, 0.5], - batchsize*16, + datapath="./data/imagenet/ILSVRC2012_img_train/", + batch_size=batchsize*16, + split=[0.5, 0.5], ).generate_data_loader() else: train_loader, _ = get_normal_dataloader(name, batchsize*16) diff --git a/xnas/core/builder.py b/xnas/core/builder.py index ecc672b..8a69605 100644 --- a/xnas/core/builder.py +++ b/xnas/core/builder.py @@ -51,6 +51,8 @@ from xnas.spaces.DropNAS.cnn import _DropNASCNN from xnas.spaces.OFA.MobileNetV3.ofa_cnn import _OFAMobileNetV3 from xnas.spaces.OFA.ProxylessNet.ofa_cnn import _OFAProxylessNASNet from xnas.spaces.OFA.ResNets.ofa_cnn import _OFAResNet +from xnas.spaces.BigNAS.cnn import _BigNAS_CNN, _infer_BigNAS_CNN +from xnas.spaces.AttentiveNAS.cnn import _AttentiveNAS_CNN, _infer_AttentiveNAS_CNN from xnas.spaces.NASBenchMacro.cnn import _NBMacro_child_train, _NBMacro_sup_train SUPPORTED_SPACES = { @@ -63,15 +65,19 @@ SUPPORTED_SPACES = { "gdas_nb201": _GDAS_nb201_CNN, "dropnas": _DropNASCNN, "spos": _SPOS_CNN, + "spos_nb201": _SPOS_nb201_CNN, "nasbenchmacro": _NBMacro_sup_train, "ofa_mbv3": _OFAMobileNetV3, "ofa_proxyless": _OFAProxylessNASNet, "ofa_resnet": _OFAResNet, - # models for inference + "attentivenas": _AttentiveNAS_CNN, + "bignas": _BigNAS_CNN, + # ===== models for inference ===== "infer_darts": _infer_DartsCNN, "infer_nb201": _infer_NASBench201, "infer_spos": _infer_SPOS_CNN, - "spos_nb201": _SPOS_nb201_CNN, + "infer_attentivenas": _infer_AttentiveNAS_CNN, + # "infer_bignas": _infer_BigNAS_CNN, } diff --git a/xnas/spaces/AttentiveNAS/cnn.py b/xnas/spaces/AttentiveNAS/cnn.py new file mode 100644 index 0000000..e596b43 --- /dev/null +++ b/xnas/spaces/AttentiveNAS/cnn.py @@ -0,0 +1,652 @@ +# Implementation adapted from AttentiveNAS: https://github.com/facebookresearch/AttentiveNAS + +import random +from copy import deepcopy +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from xnas.spaces.OFA.ops import ResidualBlock +from xnas.spaces.OFA.dynamic_ops import DynamicLinearLayer +from xnas.spaces.OFA.utils import val2list, make_divisible +from xnas.spaces.BigNAS.dynamic_layers import DynamicMBConvLayer, DynamicConvLayer, DynamicShortcutLayer + + +class AttentiveNasStaticModel(nn.Module): + + def __init__(self, first_conv, blocks, last_conv, classifier, resolution, use_v3_head=True): + super(AttentiveNasStaticModel, self).__init__() + + self.first_conv = first_conv + self.blocks = nn.ModuleList(blocks) + self.last_conv = last_conv + self.classifier = classifier + + self.resolution = resolution #input size + self.use_v3_head = use_v3_head + + def forward(self, x): + # resize input to target resolution first + # Rule: transform images into different sizes + if x.size(-1) != self.resolution: + x = F.interpolate(x, size=self.resolution, mode='bicubic') + + x = self.first_conv(x) + for block in self.blocks: + x = block(x) + x = self.last_conv(x) + if not self.use_v3_head: + x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling + x = torch.squeeze(x) + x = self.classifier(x) + return x + + + @property + def module_str(self): + _str = self.first_conv.module_str + '\n' + for block in self.blocks: + _str += block.module_str + '\n' + #_str += self.last_conv.module_str + '\n' + _str += self.classifier.module_str + return _str + + @property + def config(self): + return { + 'name': AttentiveNasStaticModel.__name__, + 'bn': self.get_bn_param(), + 'first_conv': self.first_conv.config, + 'blocks': [ + block.config for block in self.blocks + ], + #'last_conv': self.last_conv.config, + 'classifier': self.classifier.config, + 'resolution': self.resolution + } + + + def weight_initialization(self): + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + @staticmethod + def build_from_config(config): + raise NotImplementedError + + def set_bn_param(self, momentum, eps): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm): + if momentum is not None: + m.momentum = float(momentum) + else: + m.momentum = None + m.eps = float(eps) + return + + def get_bn_param(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm): + return { + 'momentum': m.momentum, + 'eps': m.eps, + } + return None + + def reset_running_stats_for_calibration(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm): + m.training = True + m.momentum = None # cumulative moving average + m.reset_running_stats() + + +class AttentiveNasDynamicModel(nn.Module): + + def __init__(self, supernet_cfg, n_classes=1000, bn_param=(0., 1e-5)): + super(AttentiveNasDynamicModel, self).__init__() + + self.supernet_cfg = supernet_cfg + self.n_classes = n_classes + self.use_v3_head = getattr(self.supernet_cfg, 'use_v3_head', False) + self.stage_names = ['first_conv', 'mb1', 'mb2', 'mb3', 'mb4', 'mb5', 'mb6', 'mb7', 'last_conv'] + + self.width_list, self.depth_list, self.ks_list, self.expand_ratio_list = [], [], [], [] + for name in self.stage_names: + block_cfg = getattr(self.supernet_cfg, name) + self.width_list.append(block_cfg.c) + if name.startswith('mb'): + self.depth_list.append(block_cfg.d) + self.ks_list.append(block_cfg.k) + self.expand_ratio_list.append(block_cfg.t) + self.resolution_list = self.supernet_cfg.resolutions + + self.cfg_candidates = { + 'resolution': self.resolution_list, + 'width': self.width_list, + 'depth': self.depth_list, + 'kernel_size': self.ks_list, + 'expand_ratio': self.expand_ratio_list + } + + #first conv layer, including conv, bn, act + out_channel_list, act_func, stride = \ + self.supernet_cfg.first_conv.c, self.supernet_cfg.first_conv.act_func, self.supernet_cfg.first_conv.s + self.first_conv = DynamicConvLayer( + in_channel_list=val2list(3), out_channel_list=out_channel_list, + kernel_size=3, stride=stride, act_func=act_func, + ) + + # inverted residual blocks + self.block_group_info = [] + blocks = [] + _block_index = 0 + feature_dim = out_channel_list + for stage_id, key in enumerate(self.stage_names[1:-1]): + block_cfg = getattr(self.supernet_cfg, key) + width = block_cfg.c + n_block = max(block_cfg.d) + act_func = block_cfg.act_func + ks = block_cfg.k + expand_ratio_list = block_cfg.t + use_se = block_cfg.se + + self.block_group_info.append([_block_index + i for i in range(n_block)]) + _block_index += n_block + + output_channel = width + for i in range(n_block): + stride = block_cfg.s if i == 0 else 1 + if min(expand_ratio_list) >= 4: + expand_ratio_list = [_s for _s in expand_ratio_list if _s >= 4] if i == 0 else expand_ratio_list + mobile_inverted_conv = DynamicMBConvLayer( + in_channel_list=feature_dim, + out_channel_list=output_channel, + kernel_size_list=ks, + expand_ratio_list=expand_ratio_list, + stride=stride, + act_func=act_func, + use_se=use_se, + channels_per_group=getattr(self.supernet_cfg, 'channels_per_group', 1) + ) + # Rule: add skip-connect, and use 2x2 AvgPool or 1x1 Conv for adaptation + shortcut = DynamicShortcutLayer(feature_dim, output_channel, reduction=stride) + blocks.append(ResidualBlock(mobile_inverted_conv, shortcut)) + feature_dim = output_channel + self.blocks = nn.ModuleList(blocks) + + last_channel, act_func = self.supernet_cfg.last_conv.c, self.supernet_cfg.last_conv.act_func + if not self.use_v3_head: + self.last_conv = DynamicConvLayer( + in_channel_list=feature_dim, out_channel_list=last_channel, + kernel_size=1, act_func=act_func, + ) + else: + expand_feature_dim = [f_dim * 6 for f_dim in feature_dim] + self.last_conv = nn.Sequential(OrderedDict([ + ('final_expand_layer', DynamicConvLayer( + feature_dim, expand_feature_dim, kernel_size=1, use_bn=True, act_func=act_func) + ), + ('pool', nn.AdaptiveAvgPool2d((1,1))), + ('feature_mix_layer', DynamicConvLayer( + in_channel_list=expand_feature_dim, out_channel_list=last_channel, + kernel_size=1, act_func=act_func, use_bn=False,) + ), + ])) + + #final conv layer + self.classifier = DynamicLinearLayer( + in_features_list=last_channel, out_features=n_classes, bias=True + ) + + # set bn param + self.set_bn_param(momentum=bn_param[0], eps=bn_param[1]) + + # runtime_depth + self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info] + + self.zero_residual_block_bn_weights() + + self.active_dropout_rate = 0 + self.active_drop_connect_rate = 0 + self.active_resolution = 224 + + # Rule: Initialize learnable coefficient \gamma=0 + def zero_residual_block_bn_weights(self): + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, ResidualBlock): + if isinstance(m.mobile_inverted_conv, DynamicMBConvLayer) and m.shortcut is not None: + m.mobile_inverted_conv.point_linear.bn.bn.weight.zero_() + + @staticmethod + def name(): + return 'AttentiveNasModel' + + def forward(self, x): + # resize input to target resolution first + if x.size(-1) != self.active_resolution: + x = F.interpolate(x, size=self.active_resolution, mode='bicubic') + + # first conv + x = self.first_conv(x) + # blocks + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + for idx in active_idx: + x = self.blocks[idx](x) + + x = self.last_conv(x) + x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling + x = torch.squeeze(x) + + if self.active_dropout_rate > 0 and self.training: + x = F.dropout(x, p = self.active_dropout_rate) + + x = self.classifier(x) + return x + + + @property + def module_str(self): + _str = self.first_conv.module_str + '\n' + _str += self.blocks[0].module_str + '\n' + + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + for idx in active_idx: + _str += self.blocks[idx].module_str + '\n' + if not self.use_v3_head: + _str += self.last_conv.module_str + '\n' + else: + _str += self.last_conv.final_expand_layer.module_str + '\n' + _str += self.last_conv.feature_mix_layer.module_str + '\n' + _str += self.classifier.module_str + '\n' + return _str + + @property + def config(self): + return { + 'name': AttentiveNasDynamicModel.__name__, + 'bn': self.get_bn_param(), + 'first_conv': self.first_conv.config, + 'blocks': [ + block.config for block in self.blocks + ], + 'last_conv': self.last_conv.config if not self.use_v3_head else None, + 'final_expand_layer': self.last_conv.final_expand_layer if self.use_v3_head else None, + 'feature_mix_layer': self.last_conv.feature_mix_layer if self.use_v3_head else None, + 'classifier': self.classifier.config, + 'resolution': self.active_resolution + } + + + @staticmethod + def build_from_config(config): + raise NotImplementedError + + def get_parameters(self, keys=None, mode="include"): + if keys is None: + for name, param in self.named_parameters(): + if param.requires_grad: + yield param + elif mode == "include": + for name, param in self.named_parameters(): + flag = False + for key in keys: + if key in name: + flag = True + break + if flag and param.requires_grad: + yield param + elif mode == "exclude": + for name, param in self.named_parameters(): + flag = True + for key in keys: + if key in name: + flag = False + break + if flag and param.requires_grad: + yield param + else: + raise ValueError("do not support: %s" % mode) + + def set_bn_param(self, momentum, eps): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm): + if momentum is not None: + m.momentum = float(momentum) + else: + m.momentum = None + m.eps = float(eps) + return + + def get_bn_param(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm): + return { + 'momentum': m.momentum, + 'eps': m.eps, + } + return None + + """ set, sample and get active sub-networks """ + def set_active_subnet(self, resolution=224, width=None, depth=None, kernel_size=None, expand_ratio=None, **kwargs): + assert len(depth) == len(kernel_size) == len(expand_ratio) == len(width) - 2 + #set resolution + self.active_resolution = resolution + + # first conv + self.first_conv.active_out_channel = width[0] + + for stage_id, (c, k, e, d) in enumerate(zip(width[1:-1], kernel_size, expand_ratio, depth)): + start_idx, end_idx = min(self.block_group_info[stage_id]), max(self.block_group_info[stage_id]) + for block_id in range(start_idx, start_idx+d): + block = self.blocks[block_id] + #block output channels + block.mobile_inverted_conv.active_out_channel = c + if block.shortcut is not None: + block.shortcut.active_out_channel = c + + #dw kernel size + block.mobile_inverted_conv.active_kernel_size = k + + #dw expansion ration + block.mobile_inverted_conv.active_expand_ratio = e + + #IRBlocks repated times + for i, d in enumerate(depth): + self.runtime_depth[i] = min(len(self.block_group_info[i]), d) + + #last conv + if not self.use_v3_head: + self.last_conv.active_out_channel = width[-1] + else: + # default expansion ratio: 6 + self.last_conv.final_expand_layer.active_out_channel = width[-2] * 6 + self.last_conv.feature_mix_layer.active_out_channel = width[-1] + + + def get_active_subnet_settings(self): + r = self.active_resolution + width, depth, kernel_size, expand_ratio= [], [], [], [] + + #first conv + width.append(self.first_conv.active_out_channel) + for stage_id in range(len(self.block_group_info)): + start_idx = min(self.block_group_info[stage_id]) + block = self.blocks[start_idx] #first block + width.append(block.mobile_inverted_conv.active_out_channel) + kernel_size.append(block.mobile_inverted_conv.active_kernel_size) + expand_ratio.append(block.mobile_inverted_conv.active_expand_ratio) + depth.append(self.runtime_depth[stage_id]) + + if not self.use_v3_head: + width.append(self.last_conv.active_out_channel) + else: + width.append(self.last_conv.feature_mix_layer.active_out_channel) + + return { + 'resolution': r, + 'width': width, + 'kernel_size': kernel_size, + 'expand_ratio': expand_ratio, + 'depth': depth, + } + + def set_dropout_rate(self, dropout=0, drop_connect=0, drop_connect_only_last_two_stages=True): + self.active_dropout_rate = dropout + for idx, block in enumerate(self.blocks): + if drop_connect_only_last_two_stages: + if idx not in self.block_group_info[-1] + self.block_group_info[-2]: + continue + this_drop_connect_rate = drop_connect * float(idx) / len(self.blocks) + block.drop_connect_rate = this_drop_connect_rate + + + def sample_min_subnet(self): + return self._sample_active_subnet(min_net=True) + + + def sample_max_subnet(self): + return self._sample_active_subnet(max_net=True) + + + def sample_active_subnet(self, compute_flops=False): + cfg = self._sample_active_subnet( + False, False + ) + if compute_flops: + cfg['flops'] = self.compute_active_subnet_flops() + return cfg + + + def sample_active_subnet_within_range(self, targeted_min_flops, targeted_max_flops): + while True: + cfg = self._sample_active_subnet() + cfg['flops'] = self.compute_active_subnet_flops() + if cfg['flops'] >= targeted_min_flops and cfg['flops'] <= targeted_max_flops: + return cfg + + def _sample_active_subnet(self, min_net=False, max_net=False): + + sample_cfg = lambda candidates, sample_min, sample_max: \ + min(candidates) if sample_min else (max(candidates) if sample_max else random.choice(candidates)) + + cfg = {} + # sample a resolution + cfg['resolution'] = sample_cfg(self.cfg_candidates['resolution'], min_net, max_net) + for k in ['width', 'depth', 'kernel_size', 'expand_ratio']: + cfg[k] = [] + for vv in self.cfg_candidates[k]: + cfg[k].append(sample_cfg(val2list(vv), min_net, max_net)) + + self.set_active_subnet( + cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio'] + ) + return cfg + + + def mutate_and_reset(self, cfg, prob=0.1, keep_resolution=False): + cfg = deepcopy(cfg) + pick_another = lambda x, candidates: x if len(candidates) == 1 else random.choice([v for v in candidates if v != x]) + # sample a resolution + r = random.random() + if r < prob and not keep_resolution: + cfg['resolution'] = pick_another(cfg['resolution'], self.cfg_candidates['resolution']) + + # sample channels, depth, kernel_size, expand_ratio + for k in ['width', 'depth', 'kernel_size', 'expand_ratio']: + for _i, _v in enumerate(cfg[k]): + r = random.random() + if r < prob: + cfg[k][_i] = pick_another(cfg[k][_i], val2list(self.cfg_candidates[k][_i])) + + self.set_active_subnet( + cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio'] + ) + return cfg + + + def crossover_and_reset(self, cfg1, cfg2, p=0.5): + def _cross_helper(g1, g2, prob): + assert type(g1) == type(g2) + if isinstance(g1, int): + return g1 if random.random() < prob else g2 + elif isinstance(g1, list): + return [v1 if random.random() < prob else v2 for v1, v2 in zip(g1, g2)] + else: + raise NotImplementedError + + cfg = {} + cfg['resolution'] = cfg1['resolution'] if random.random() < p else cfg2['resolution'] + for k in ['width', 'depth', 'kernel_size', 'expand_ratio']: + cfg[k] = _cross_helper(cfg1[k], cfg2[k], p) + + self.set_active_subnet( + cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio'] + ) + return cfg + + + def get_active_subnet(self, preserve_weight=True): + with torch.no_grad(): + first_conv = self.first_conv.get_active_subnet(3, preserve_weight) + + blocks = [] + input_channel = first_conv.out_channels + # blocks + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + stage_blocks = [] + for idx in active_idx: + stage_blocks.append(ResidualBlock( + self.blocks[idx].mobile_inverted_conv.get_active_subnet(input_channel, preserve_weight), + self.blocks[idx].shortcut.get_active_subnet(input_channel, preserve_weight) if self.blocks[idx].shortcut is not None else None + )) + input_channel = stage_blocks[-1].mobile_inverted_conv.out_channels + blocks += stage_blocks + + if not self.use_v3_head: + last_conv = self.last_conv.get_active_subnet(input_channel, preserve_weight) + in_features = last_conv.out_channels + else: + final_expand_layer = self.last_conv.final_expand_layer.get_active_subnet(input_channel, preserve_weight) + feature_mix_layer = self.last_conv.feature_mix_layer.get_active_subnet(input_channel*6, preserve_weight) + in_features = feature_mix_layer.out_channels + last_conv = nn.Sequential( + final_expand_layer, + nn.AdaptiveAvgPool2d((1,1)), + feature_mix_layer + ) + + classifier = self.classifier.get_active_subnet(in_features, preserve_weight) + + _subnet = AttentiveNasStaticModel( + first_conv, blocks, last_conv, classifier, self.active_resolution, use_v3_head=self.use_v3_head + ) + _subnet.set_bn_param(**self.get_bn_param()) + return _subnet + + + def compute_active_subnet_flops(self): + + def count_conv(c_in, c_out, size_out, groups, k): + kernel_ops = k**2 + output_elements = c_out * size_out**2 + ops = c_in * output_elements * kernel_ops / groups + return ops + + def count_linear(c_in, c_out): + return c_in * c_out + + total_ops = 0 + + c_in = 3 + size_out = self.active_resolution // self.first_conv.stride + c_out = self.first_conv.active_out_channel + + total_ops += count_conv(c_in, c_out, size_out, 1, 3) + c_in = c_out + + # mb blocks + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + for idx in active_idx: + block = self.blocks[idx] + c_middle = make_divisible(round(c_in * block.mobile_inverted_conv.active_expand_ratio), 8) + # 1*1 conv + if block.mobile_inverted_conv.inverted_bottleneck is not None: + total_ops += count_conv(c_in, c_middle, size_out, 1, 1) + # dw conv + stride = 1 if idx > active_idx[0] else block.mobile_inverted_conv.stride + if size_out % stride == 0: + size_out = size_out // stride + else: + size_out = (size_out +1) // stride + total_ops += count_conv(c_middle, c_middle, size_out, c_middle, block.mobile_inverted_conv.active_kernel_size) + # 1*1 conv + c_out = block.mobile_inverted_conv.active_out_channel + total_ops += count_conv(c_middle, c_out, size_out, 1, 1) + #se + if block.mobile_inverted_conv.use_se: + num_mid = make_divisible(c_middle // block.mobile_inverted_conv.depth_conv.se.reduction, divisor=8) + total_ops += count_conv(c_middle, num_mid, 1, 1, 1) * 2 + if block.shortcut and c_in != c_out: + total_ops += count_conv(c_in, c_out, size_out, 1, 1) + c_in = c_out + + if not self.use_v3_head: + c_out = self.last_conv.active_out_channel + total_ops += count_conv(c_in, c_out, size_out, 1, 1) + else: + c_expand = self.last_conv.final_expand_layer.active_out_channel + c_out = self.last_conv.feature_mix_layer.active_out_channel + total_ops += count_conv(c_in, c_expand, size_out, 1, 1) + total_ops += count_conv(c_expand, c_out, 1, 1, 1) + + # n_classes + total_ops += count_linear(c_out, self.n_classes) + return total_ops / 1e6 + + + def load_weights_from_pretrained_models(self, checkpoint_path): + with open(checkpoint_path, 'rb') as f: + checkpoint = torch.load(f, map_location='cpu') + assert isinstance(checkpoint, dict) + pretrained_state_dicts = checkpoint['state_dict'] + for k, v in self.state_dict().items(): + name = 'module.' + k if not k.startswith('module') else k + v.copy_(pretrained_state_dicts[name]) + + +def _AttentiveNAS_CNN(): + from xnas.core.config import cfg + bn_momentum = cfg.ATTENTIVENAS.BN_MOMENTUM + bn_eps = cfg.ATTENTIVENAS.BN_EPS + return AttentiveNasDynamicModel( + cfg.ATTENTIVENAS.SUPERNET_CFG, + cfg.LOADER.NUM_CLASSES, + (bn_momentum, bn_eps), + ) + +def _infer_AttentiveNAS_CNN(): + from xnas.core.config import cfg + bn_momentum = cfg.ATTENTIVENAS.BN_MOMENTUM + bn_eps = cfg.ATTENTIVENAS.BN_EPS + supernet = AttentiveNasDynamicModel( + cfg.ATTENTIVENAS.SUPERNET_CFG, + cfg.LOADER.NUM_CLASSES, + (bn_momentum, bn_eps), + ) + # namespace changed: pareto_models.supernet_checkpoint_path + supernet.load_weights_from_pretrained_models(cfg.SEARCH.WEIGHT) + # namespace created: active_subnet.* + supernet.set_active_subnet( + resolution=cfg.ATTENTIVENAS.ACTIVE_SUBNET.RESOLUTION, + width = cfg.ATTENTIVENAS.ACTIVE_SUBNET.WIDTH, + depth = cfg.ATTENTIVENAS.ACTIVE_SUBNET.DEPTH, + kernel_size = cfg.ATTENTIVENAS.ACTIVE_SUBNET.KERNEL_SIZE, + expand_ratio = cfg.ATTENTIVENAS.ACTIVE_SUBNET.EXPAND_RATIO, + ) + model = supernet.get_active_subnet() + # house-keeping stuff: may using different values with supernet + model.set_bn_param(momentum=bn_momentum, eps=bn_eps) + del supernet + return model -- 2.34.1 From 1779aaa6e6adf8ca2dac27c30761ea8e3a8007fd Mon Sep 17 00:00:00 2001 From: xfey Date: Wed, 22 Jun 2022 21:56:40 +0800 Subject: [PATCH 7/7] remove [debug] labels --- scripts/search/BigNAS/search.py | 4 ---- scripts/search/BigNAS/train_supernet.py | 8 -------- 2 files changed, 12 deletions(-) diff --git a/scripts/search/BigNAS/search.py b/scripts/search/BigNAS/search.py index 4a52a01..2bb3f22 100644 --- a/scripts/search/BigNAS/search.py +++ b/scripts/search/BigNAS/search.py @@ -163,10 +163,6 @@ def test_epoch(subnet, test_loader, test_meter): test_meter.reset(True) test_meter.iter_tic() for cur_iter, (inputs, labels) in enumerate(test_loader): - # [debug] - if cur_iter > 20: - break - inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) preds = subnet(inputs) top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) diff --git a/scripts/search/BigNAS/train_supernet.py b/scripts/search/BigNAS/train_supernet.py index bc968bd..bc984b3 100644 --- a/scripts/search/BigNAS/train_supernet.py +++ b/scripts/search/BigNAS/train_supernet.py @@ -108,10 +108,6 @@ class BigNASTrainer(Trainer): self.train_meter.iter_tic() self.train_loader.sampler.set_epoch(cur_epoch) # DDP for cur_iter, (inputs, labels) in enumerate(self.train_loader): - # [debug] - if cur_iter > 20: - break - inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True) # Adjust lr per iter @@ -183,10 +179,6 @@ class BigNASTrainer(Trainer): self.test_meter.reset(True) self.test_meter.iter_tic() for cur_iter, (inputs, labels) in enumerate(self.test_loader): - # [debug] - if cur_iter > 20: - break - inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True) preds = subnet(inputs) top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) -- 2.34.1