#6 dev

Merged
xfey merged 5 commits from dev into master 2 years ago
1. configs/search/DrNAS/nb201_cifar10_DARTS.yaml (+18, -2)
2. configs/search/DrNAS/nb201_cifar10_Dirichlet.yaml (+18, -2)
3. configs/search/DrNAS/nb201_cifar10_GDAS.yaml (+17, -0)
4. configs/search/DrNAS/nb201_cifar10_SNAS.yaml (+18, -2)
5. search/DrNAS/DARTSspace.py (+16, -184)
6. search/DrNAS/DARTSspace_img.py (+5, -145)
7. search/DrNAS/nb201space.py (+217, -255)
8. search/DrNAS/nb201space_progressive.py (+241, -267)
9. search/PCDARTS_search.py (+24, -19)
10. tools/test_func/data_loader_test.py (+44, -39)
11. xnas/core/builders.py (+6, -5)
12. xnas/core/config.py (+3, -0)
13. xnas/datasets/cifar10.py (+0, -160)
14. xnas/datasets/imagenet.py (+3, -2)
15. xnas/datasets/imagenet16.py (+9, -1)
16. xnas/datasets/loader.py (+128, -45)
17. xnas/datasets/torchDataset.py (+0, -123)
18. xnas/datasets/transforms.py (+117, -2)
19. xnas/search_space/DrNAS/DARTSspace/cnn.py (+3, -3)
20. xnas/search_space/DrNAS/nb201space/cnn.py (+4, -4)

configs/search/DrNAS/nb201_cifar10_DARTS.yaml (+18, -2)

@@ -5,9 +5,25 @@ SPACE:
NODES: 4
SEARCH:
DATASET: 'cifar10'
DATAPATH: "data/cifar10"
NUM_CLASSES: 10
SPLIT: [0.5, 0.5]
LOSS_FUN: "cross_entropy"
BATCH_SIZE: 64
DATA_LOADER:
NUM_WORKERS: 4
OPTIM:
BASE_LR: 0.025
MIN_LR: 0.001
MOMENTUM: 0.9
WEIGHT_DECAY: 3e-4
GRAD_CLIP: 5.0
MAX_EPOCH: 100
DRNAS:
K: 1
UNROLLED: False
METHOD: 'darts'
TAU: [1, 10]
K: 1 # k=4 for progressive
REG_TYPE: "l2"
REG_SCALE: 1e-3
REG_SCALE: 1e-3
RNG_SEED: 2
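
The four configs in this PR drive the refactored DrNAS entry points through the yacs-style `cfg` object. Below is a minimal sketch of that flow, reusing the `load_cfg_fom_args` / `assert_and_infer_cfg` / `freeze` calls that appear in the nb201space.py diff further down; how the YAML path is passed on the command line is not shown in this PR, so treat that part as an assumption.

```python
# Minimal sketch (not part of this PR) of how a search script consumes these
# YAML files, mirroring the pattern in search/DrNAS/nb201space.py below.
import xnas.core.config as config
from xnas.core.config import cfg

config.load_cfg_fom_args()     # parse CLI args and merge the chosen YAML into cfg
config.assert_and_infer_cfg()  # sanity-check the merged configuration
cfg.freeze()                   # make the config immutable for the run

# Values added by these configs:
print(cfg.DRNAS.METHOD)        # 'darts' / 'dirichlet' / 'gdas' / 'snas'
print(cfg.DRNAS.TAU)           # [tau_min, tau_max] for the Gumbel softmax
print(cfg.OPTIM.MAX_EPOCH)     # 100
```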

configs/search/DrNAS/nb201_cifar10_Dirichlet.yaml (+18, -2)

@@ -5,9 +5,25 @@ SPACE:
NODES: 4
SEARCH:
DATASET: 'cifar10'
DATAPATH: "data/cifar10"
NUM_CLASSES: 10
SPLIT: [0.5, 0.5]
LOSS_FUN: "cross_entropy"
BATCH_SIZE: 64
DATA_LOADER:
NUM_WORKERS: 4
OPTIM:
BASE_LR: 0.025
MIN_LR: 0.001
MOMENTUM: 0.9
WEIGHT_DECAY: 3e-4
GRAD_CLIP: 5.0
MAX_EPOCH: 100
DRNAS:
K: 1
UNROLLED: False
METHOD: 'dirichlet'
TAU: [1, 10]
K: 1 # k=4 for progressive
REG_TYPE: "l2" # can be "kl" for Dirichlet
REG_SCALE: 1e-3
REG_SCALE: 1e-3
RNG_SEED: 2

configs/search/DrNAS/nb201_cifar10_GDAS.yaml (+17, -0)

@@ -5,5 +5,22 @@ SPACE:
NODES: 4
SEARCH:
DATASET: 'cifar10'
DATAPATH: "data/cifar10"
NUM_CLASSES: 10
SPLIT: [0.5, 0.5]
LOSS_FUN: "cross_entropy"
BATCH_SIZE: 64
DATA_LOADER:
NUM_WORKERS: 4
OPTIM:
BASE_LR: 0.025
MIN_LR: 0.001
MOMENTUM: 0.9
WEIGHT_DECAY: 3e-4
GRAD_CLIP: 5.0
MAX_EPOCH: 100
DRNAS:
UNROLLED: False
METHOD: 'gdas'
TAU: [1, 10]
RNG_SEED: 2

configs/search/DrNAS/nb201_cifar10_SNAS.yaml (+18, -2)

@@ -5,9 +5,25 @@ SPACE:
NODES: 4
SEARCH:
DATASET: 'cifar10'
DATAPATH: "data/cifar10"
NUM_CLASSES: 10
SPLIT: [0.5, 0.5]
LOSS_FUN: "cross_entropy"
BATCH_SIZE: 64
DATA_LOADER:
NUM_WORKERS: 4
OPTIM:
BASE_LR: 0.025
MIN_LR: 0.001
MOMENTUM: 0.9
WEIGHT_DECAY: 3e-4
GRAD_CLIP: 5.0
MAX_EPOCH: 100
DRNAS:
K: 1
UNROLLED: False
METHOD: 'snas'
TAU: [1, 10]
K: 1 # k=4 for progressive
REG_TYPE: "l2"
REG_SCALE: 1e-3
REG_SCALE: 1e-3
RNG_SEED: 2

search/DrNAS/DARTSspace.py (+16, -184)

@@ -28,103 +28,14 @@ writer = SummaryWriter(log_dir=os.path.join(cfg.OUT_DIR, "tb"))
logger = logging.get_logger(__name__)


# parser = argparse.ArgumentParser("cifar")
# parser.add_argument(
# "--data", type=str, default="datapath", help="location of the data corpus"
# )
# parser.add_argument(
# "--dataset", type=str, default="cifar10", help="location of the data corpus"
# )
# parser.add_argument("--batch_size", type=int, default=64, help="batch size")
# parser.add_argument(
# "--learning_rate", type=float, default=0.1, help="init learning rate"
# )
# parser.add_argument(
# "--learning_rate_min", type=float, default=0.0, help="min learning rate"
# )
# parser.add_argument("--momentum", type=float, default=0.9, help="momentum")
# parser.add_argument("--weight_decay", type=float, default=3e-4, help="weight decay")
# parser.add_argument("--report_freq", type=float, default=50, help="report frequency")
# parser.add_argument("--gpu", type=int, default=0, help="gpu device id")
# parser.add_argument(
# "--init_channels", type=int, default=36, help="num of init channels"
# )
# parser.add_argument("--layers", type=int, default=20, help="total number of layers")
# parser.add_argument("--save", type=str, default="exp", help="experiment name")
# parser.add_argument("--seed", type=int, default=2, help="random seed")
# parser.add_argument("--grad_clip", type=float, default=5, help="gradient clipping")
# parser.add_argument(
# "--train_portion", type=float, default=0.5, help="portion of training data"
# )
# parser.add_argument(
# "--unrolled",
# action="store_true",
# default=False,
# help="use one-step unrolled validation loss",
# )
# parser.add_argument(
# "--arch_learning_rate",
# type=float,
# default=6e-4,
# help="learning rate for arch encoding",
# )
# parser.add_argument("--k", type=int, default=6, help="init partial channel parameter")
#### regularization
# parser.add_argument(
# "--reg_type",
# type=str,
# default="l2",
# choices=["l2", "kl"],
# help="regularization type",
# )
# parser.add_argument(
# "--reg_scale",
# type=float,
# default=1e-3,
# help="scaling factor of the regularization term, default value is proper for l2, for kl you might adjust reg_scale to match l2",
# )
# args = parser.parse_args()


# cfg.OUT_DIR = "../experiments/{}/search-progressive-{}-{}-{}".format(
# cfg.SEARCH.DATASET, cfg.OUT_DIR, time.strftime("%Y%m%d-%H%M%S"), cfg.RNG_SEED
# )
# cfg.OUT_DIR += "-init_channels-" + str(cfg.SPACE.CHANNEL)
# cfg.OUT_DIR += "-layers-" + str(cfg.SPACE.LAYERS)
# cfg.OUT_DIR += "-init_pc-" + str(cfg.DRNAS.K)
# utils.create_exp_dir(cfg.OUT_DIR, scripts_to_save=glob.glob("*.py"))

# log_format = "%(asctime)s %(message)s"
# logging.basicConfig(
# stream=sys.stdout,
# level=logging.INFO,
# format=log_format,
# datefmt="%m/%d %I:%M:%S %p",
# )
# fh = logging.FileHandler(os.path.join(cfg.OUT_DIR, "log.txt"))
# fh.setFormatter(logging.Formatter(log_format))
# logging.getLogger().addHandler(fh)


def main():

setup_env()
cudnn.benchmark = True # DrNAS code sets this term to True.

criterion = build_loss_fun().cuda()
# criterion = nn.CrossEntropyLoss()
# criterion = criterion.cuda()

model = DrNAS_builder().cuda()
# model = Network(
# cfg.SPACE.CHANNEL,
# cfg.SEARCH.NUM_CLASSES,
# cfg.SPACE.LAYERS,
# criterion,
# k=cfg.DRNAS.K,
# reg_type=cfg.DRNAS.REG_TYPE,
# reg_scale=cfg.DRNAS.REG_SCALE,
# )
architect = Architect(model, cfg)

logger.info("param size = %fMB", utils.count_parameters_in_MB(model))
@@ -136,38 +47,10 @@ def main():
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
)

# train_transform, valid_transform = utils._data_transforms_cifar10(args)
# if cfg.SEARCH.DATASET == "cifar100":
# train_data = dset.CIFAR100(
# root=cfg.SEARCH.DATAPATH, train=True, download=True, transform=train_transform
# )
# else:
# train_data = dset.CIFAR10(
# root=cfg.SEARCH.DATAPATH, train=True, download=True, transform=train_transform
# )

[train_loader, valid_loader] = construct_loader(
cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT, cfg.SEARCH.BATCH_SIZE, cfg.SEARCH.DATAPATH
)

# num_train = len(train_data)
# indices = list(range(num_train))
# split = int(np.floor(args.train_portion * num_train))

# train_queue = torch.utils.data.DataLoader(
# train_data,
# batch_size=cfg.SEARCH.BATCH_SIZE,
# sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
# pin_memory=True,
# )

# valid_queue = torch.utils.data.DataLoader(
# train_data,
# batch_size=cfg.SEARCH.BATCH_SIZE,
# sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
# pin_memory=True,
# )

# configure progressive parameter
epoch = 0
ks = [6, 4]
@@ -178,13 +61,14 @@ def main():
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, float(sum(train_epochs)), eta_min=cfg.OPTIM.MIN_LR
)

train_meter = meters.TrainMeter(len(train_loader))
val_meter = meters.TestMeter(len(valid_loader))

train_timer = Timer()
for i, current_epochs in enumerate(train_epochs):
logger.info("train period #{} total epochs {}".format(i, current_epochs))
for e in range(current_epochs):
# train_timer = Timer()
for i, current_epoch in enumerate(train_epochs):
logger.info("train period #{} total epochs {}".format(i, current_epoch))
for e in range(current_epoch):
lr = lr_scheduler.get_lr()[0]
logger.info("epoch %d lr %e", epoch, lr)

@@ -192,9 +76,9 @@ def main():
logger.info("genotype = %s", genotype)
model.show_arch_parameters(logger)

train_timer.tic()
# train_timer.tic()
# training
train_acc = train_epoch(
top1err = train_epoch(
train_loader,
valid_loader,
model,
@@ -205,24 +89,23 @@ def main():
train_meter,
e,
)
logger.info("train_acc %f", train_acc)

train_timer.toc()
print("epoch time:{}".format(train_timer.diff))
logger.info("Top1 err:%f", top1err)
# train_timer.toc()
# print("epoch time:{}".format(train_timer.diff))

# validation
# valid_acc, valid_obj = infer(valid_queue, model, criterion)
# logger.info("valid_acc %f", valid_acc)
test_epoch(valid_loader, model, val_meter, current_epochs, writer)
test_epoch(valid_loader, model, val_meter, epoch, writer)

epoch += 1
lr_scheduler.step()

if epoch % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
save_ckpt(model, os.path.join(cfg.OUT_DIR, "weights_epo" + str(epoch) + ".pt"))
utils.save(
model, os.path.join(cfg.OUT_DIR, "weights_epo" + str(epoch) + ".pt")
)

print("avg epoch time:{}".format(train_timer.average_time))
train_timer.reset()
# print("avg epoch time:{}".format(train_timer.average_time))
# train_timer.reset()

if not i == len(train_epochs) - 1:
model.pruning(num_keeps[i + 1])
@@ -268,9 +151,6 @@ def train_epoch(

valid_loader_iter = iter(valid_loader)

# objs = utils.AvgrageMeter()
# top1 = utils.AvgrageMeter()
# top5 = utils.AvgrageMeter()
for cur_iter, (trn_X, trn_y) in enumerate(train_loader):
model.train()
try:
@@ -319,54 +199,6 @@ def train_epoch(
train_meter.reset()
return top1_err

# prec1, prec5 = utils.accuracy(logits, trn_y, topk=(1, 5))
# objs.update(loss.data, n)
# top1.update(prec1.data, n)
# top5.update(prec5.data, n)

# if step % cfg.SEARCH.EVAL_PERIOD == 0:
# logger.info("train %03d %e %f %f", step, objs.avg, top1.avg, top5.avg)
# if "debug" in cfg.OUT_DIR:
# break

# return top1.avg, objs.avg


# def infer(valid_queue, model, criterion):
# objs = utils.AvgrageMeter()
# top1 = utils.AvgrageMeter()
# top5 = utils.AvgrageMeter()
# model.eval()

# with torch.no_grad():
# for step, (input, target) in enumerate(valid_queue):
# input = input.cuda()
# target = target.cuda(non_blocking=True)

# logits = model(input)
# loss = criterion(logits, target)

# prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
# n = input.size(0)
# objs.update(loss.data, n)
# top1.update(prec1.data, n)
# top5.update(prec5.data, n)

# if step % cfg.SEARCH.EVAL_PERIOD == 0:
# logger.info("valid %03d %e %f %f", step, objs.avg, top1.avg, top5.avg)
# if "debug" in cfg.OUT_DIR:
# break

# return top1.avg, objs.avg


def save_ckpt(model, model_path):
torch.save(model.state_dict(), model_path)


def load_ckpt(model, model_path):
model.load_state_dict(torch.load(model_path))


if __name__ == "__main__":
main()
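
For reference, the `construct_loader` call that replaces the commented-out SubsetRandomSampler block returns a `[train_loader, valid_loader]` pair built by splitting a single training set. The sketch below shows that idea with plain torchvision; it is a hedged illustration, not the actual `xnas.datasets.loader` implementation, and the transform is a placeholder.

```python
# Illustrative split loader in the spirit of construct_loader(); NOT the xnas code.
import numpy as np
import torch.utils.data as data
import torchvision.datasets as dset
import torchvision.transforms as transforms


def simple_split_loaders(datapath, split=(0.5, 0.5), batch_size=64, num_workers=4):
    """Split one CIFAR-10 training set into train/val loaders (illustration only)."""
    train_data = dset.CIFAR10(root=datapath, train=True, download=True,
                              transform=transforms.ToTensor())  # placeholder transform
    num_train = len(train_data)
    indices = list(range(num_train))
    cut = int(np.floor(split[0] * num_train))

    train_loader = data.DataLoader(
        train_data, batch_size=batch_size,
        sampler=data.SubsetRandomSampler(indices[:cut]),
        num_workers=num_workers, pin_memory=True)
    valid_loader = data.DataLoader(
        train_data, batch_size=batch_size,
        sampler=data.SubsetRandomSampler(indices[cut:]),
        num_workers=num_workers, pin_memory=True)
    return [train_loader, valid_loader]
```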


search/DrNAS/DARTSspace_img.py (+5, -145)

@@ -30,84 +30,6 @@ writer = SummaryWriter(log_dir=os.path.join(cfg.OUT_DIR, "tb"))

logger = logging.get_logger(__name__)

# parser = argparse.ArgumentParser("imagenet")
# parser.add_argument(
# "--workers", type=int, default=16, help="number of workers to load dataset"
# )
# parser.add_argument(
# "--data", type=str, default="datapath", help="location of the data corpus"
# )
# parser.add_argument("--batch_size", type=int, default=512, help="batch size")
# parser.add_argument(
# "--learning_rate", type=float, default=0.5, help="init learning rate"
# )
# parser.add_argument(
# "--learning_rate_min", type=float, default=0.0, help="min learning rate"
# )
# parser.add_argument("--momentum", type=float, default=0.9, help="momentum")
# parser.add_argument("--weight_decay", type=float, default=3e-4, help="weight decay")
# parser.add_argument("--report_freq", type=float, default=50, help="report frequency")
# parser.add_argument(
# "--init_channels", type=int, default=48, help="num of init channels"
# )
# parser.add_argument("--layers", type=int, default=14, help="total number of layers")

# NOTE: cutout and drop_path_prob are never used.

# parser.add_argument("--cutout", action="store_true", default=False, help="use cutout")
# parser.add_argument("--cutout_length", type=int, default=16, help="cutout length")
# parser.add_argument(
# "--drop_path_prob", type=float, default=0.3, help="drop path probability"
# )

# parser.add_argument("--save", type=str, default="exp", help="experiment name")
# parser.add_argument("--seed", type=int, default=0, help="random seed")
# parser.add_argument("--grad_clip", type=float, default=5, help="gradient clipping")
# parser.add_argument(
# "--unrolled",
# action="store_true",
# default=False,
# help="use one-step unrolled validation loss",
# )
# parser.add_argument(
# "--arch_learning_rate",
# type=float,
# default=6e-3,
# help="learning rate for arch encoding",
# )
# parser.add_argument(
# "--arch_weight_decay",
# type=float,
# default=1e-3,
# help="weight decay for arch encoding",
# )
# parser.add_argument("--k", type=int, default=6, help="init partial channel parameter")
# parser.add_argument("--begin", type=int, default=10, help="warm start")

# args = parser.parse_args()

# args.save = "../experiments/imagenet/search-progressive-{}-{}-{}".format(
# args.save, time.strftime("%Y%m%d-%H%M%S"), args.seed
# )
# args.save += "-init_channels-" + str(args.init_channels)
# args.save += "-layers-" + str(args.layers)
# args.save += "-init_pc-" + str(args.k)
# utils.create_exp_dir(args.save, scripts_to_save=glob.glob("*.py"))

# log_format = "%(asctime)s %(message)s"
# logging.basicConfig(
# stream=sys.stdout,
# level=logger.info,
# format=log_format,
# datefmt="%m/%d %I:%M:%S %p",
# )
# fh = logging.FileHandler(os.path.join(args.save, "log.txt"))
# fh.setFormatter(logging.Formatter(log_format))
# logging.getLogger().addHandler(fh)

# data preparation: we randomly sample 10% and 2.5% of the training set (per class) as train and val, respectively.
# Note that the data sampling cannot use torch.utils.data.sampler.SubsetRandomSampler as ImageNet is too large.


def data_preparation():
traindir = os.path.join(cfg.SEARCH.DATAPATH, "train")
@@ -170,11 +92,9 @@ def main():
cudnn.benchmark = True # DrNAS code sets this term to True.

criterion = build_loss_fun().cuda()
# criterion = nn.CrossEntropyLoss()
# criterion = criterion.cuda()

model = DrNAS_builder()
# model = Network(cfg.SPACE.CHANNEL, cfg.SEARCH.NUM_CLASSES, cfg.SPACE.LAYERS, criterion, k=cfg.DRNAS.K)
model = nn.DataParallel(model) # TODO: parallel not tested
model = model.cuda()

@@ -231,7 +151,7 @@ def main():

train_timer.tic()
# training
train_acc = train_epoch(
top1err = train_epoch(
train_loader,
valid_loader,
model,
@@ -242,7 +162,7 @@ def main():
train_meter,
e,
)
logger.info("Train_acc %f", train_acc)
logger.info("Top1 err:%f", top1err)

train_timer.toc()
print("epoch time:{}".format(train_timer.diff))
@@ -253,13 +173,13 @@ def main():
# logger.info("Valid_acc %f", valid_acc)
# test_acc, test_obj = infer(test_queue, model, criterion)
# logger.info('Test_acc %f', test_acc)
test_epoch(valid_loader, model, val_meter, current_epochs, writer)
test_epoch(valid_loader, model, val_meter, epoch, writer)

epoch += 1
scheduler.step()

if epoch % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
save_ckpt(
utils.save(
model, os.path.join(cfg.OUT_DIR, "weights_epo" + str(epoch) + ".pt")
)

@@ -309,10 +229,6 @@ def train_epoch(

valid_loader_iter = iter(valid_loader)

# objs = utils.AvgrageMeter()
# top1 = utils.AvgrageMeter()
# top5 = utils.AvgrageMeter()

for cur_iter, (trn_X, trn_y) in enumerate(train_loader):
model.train()
try:
@@ -364,62 +280,6 @@ def train_epoch(
train_meter.reset()
return top1_err

# if step % args.report_freq == 0:
# end_time = time.time()
# if step == 0:
# duration = 0
# start_time = time.time()
# else:
# duration = end_time - start_time
# start_time = time.time()
# logger.info(
# "TRAIN Step: %03d Objs: %e R1: %f R5: %f Duration: %ds",
# step,
# objs.avg,
# top1.avg,
# top5.avg,
# duration,
# )
# if "debug" in args.save:
# break
# return top1.avg, objs.avg


# def infer(valid_queue, model, criterion):
# objs = utils.AvgrageMeter()
# top1 = utils.AvgrageMeter()
# top5 = utils.AvgrageMeter()
# model.eval()

# with torch.no_grad():
# for step, (input, target) in enumerate(valid_queue):
# input = input.cuda()
# target = target.cuda(non_blocking=True)

# logits = model(input)
# loss = criterion(logits, target)

# prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
# n = input.size(0)
# objs.update(loss.data, n)
# top1.update(prec1.data, n)
# top5.update(prec5.data, n)

# if step % args.report_freq == 0:
# logger.info("valid %03d %e %f %f", step, objs.avg, top1.avg, top5.avg)
# if "debug" in args.save:
# break

# return top1.avg, objs.avg


def save_ckpt(model, model_path):
torch.save(model.state_dict(), model_path)


def load_ckpt(model, model_path):
model.load_state_dict(torch.load(model_path))


if __name__ == "__main__":
main()
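
The comment kept at the top of `data_preparation()` describes the ImageNet strategy: sample 10% / 2.5% of each class for the search train/val splits instead of running `SubsetRandomSampler` over the whole set. Below is a hypothetical illustration of that per-class sampling; the real function may differ in transforms and loader construction.

```python
# Hedged illustration of the per-class subsampling described above; not the
# actual data_preparation() body.
import os
import random
from collections import defaultdict

import torchvision.datasets as dset


def sample_per_class(imagenet_root, train_portion=0.10, val_portion=0.025, seed=2):
    full = dset.ImageFolder(os.path.join(imagenet_root, "train"))
    by_class = defaultdict(list)
    for idx, (_, label) in enumerate(full.samples):
        by_class[label].append(idx)

    rng = random.Random(seed)
    train_idx, val_idx = [], []
    for idxs in by_class.values():
        rng.shuffle(idxs)
        n_train = int(len(idxs) * train_portion)
        n_val = int(len(idxs) * val_portion)
        train_idx += idxs[:n_train]
        val_idx += idxs[n_train:n_train + n_val]
    return train_idx, val_idx  # feed these into Subset / custom samplers
```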


search/DrNAS/nb201space.py (+217, -255)

@@ -1,323 +1,285 @@
import os
import sys
import time
import glob
import numpy as np
import torch
import logging
import argparse
import torch.nn as nn
import torch.utils
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from xnas.core.builders import DrNAS_builder

import xnas.core.config as config
import xnas.core.meters as meters
import xnas.core.logging as logging
import xnas.search_space.DrNAS.utils as utils
from xnas.search_algorithm.DrNAS import Architect
from xnas.core.builders import build_loss_fun, DrNAS_builder
from xnas.core.config import cfg
from xnas.core.timer import Timer
from xnas.core.trainer import setup_env, test_epoch
from xnas.datasets.loader import construct_loader

from torch.utils.tensorboard import SummaryWriter
from nas_201_api import NASBench201API as API

from xnas.core.config import cfg


# Load config and check
config.load_cfg_fom_args()
config.assert_and_infer_cfg()
cfg.freeze()
# Tensorboard supplement
writer = SummaryWriter(log_dir=os.path.join(cfg.OUT_DIR, "tb"))

parser = argparse.ArgumentParser("sota")
parser.add_argument('--data', type=str, default='datapath', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
parser.add_argument('--method', type=str, default='dirichlet', help='choose nas method')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--epochs', type=int, default=100, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
parser.add_argument('--save', type=str, default='exp', help='experiment name')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss')
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
parser.add_argument('--tau_max', type=float, default=10, help='Max temperature (tau) for the gumbel softmax.')
parser.add_argument('--tau_min', type=float, default=1, help='Min temperature (tau) for the gumbel softmax.')
parser.add_argument('--k', type=int, default=1, help='partial channel parameter')
#### regularization
parser.add_argument('--reg_type', type=str, default='l2', choices=[
'l2', 'kl'], help='regularization type, kl is implemented for dirichlet only')
parser.add_argument('--reg_scale', type=float, default=1e-3,
help='scaling factor of the regularization term, default value is proper for l2, for kl you might adjust reg_scale to match l2')
args = parser.parse_args()

args.save = '../experiments/nasbench201/{}-search-{}-{}-{}'.format(
args.method, args.save, time.strftime("%Y%m%d-%H%M%S"), args.seed)
if not args.dataset == 'cifar10':
args.save += '-' + args.dataset
if args.unrolled:
args.save += '-unrolled'
if not args.weight_decay == 3e-4:
args.save += '-weight_l2-' + str(args.weight_decay)
if not args.arch_weight_decay == 1e-3:
args.save += '-alpha_l2-' + str(args.arch_weight_decay)
if not args.method == 'gdas':
args.save += '-pc-' + str(args.k)

utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')


if args.dataset == 'cifar100':
n_classes = 100
elif args.dataset == 'imagenet16-120':
n_classes = 120
else:
n_classes = 10
logger = logging.get_logger(__name__)


def distill(result):
result = result.split('\n')
cifar10 = result[5].replace(' ', '').split(':')
cifar100 = result[7].replace(' ', '').split(':')
imagenet16 = result[9].replace(' ', '').split(':')

cifar10_train = float(cifar10[1].strip(',test')[-7:-2].strip('='))
cifar10_test = float(cifar10[2][-7:-2].strip('='))
cifar100_train = float(cifar100[1].strip(',valid')[-7:-2].strip('='))
cifar100_valid = float(cifar100[2].strip(',test')[-7:-2].strip('='))
cifar100_test = float(cifar100[3][-7:-2].strip('='))
imagenet16_train = float(imagenet16[1].strip(',valid')[-7:-2].strip('='))
imagenet16_valid = float(imagenet16[2].strip(',test')[-7:-2].strip('='))
imagenet16_test = float(imagenet16[3][-7:-2].strip('='))

return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test
result = result.split("\n")
cifar10 = result[5].replace(" ", "").split(":")
cifar100 = result[7].replace(" ", "").split(":")
imagenet16 = result[9].replace(" ", "").split(":")

cifar10_train = float(cifar10[1].strip(",test")[-7:-2].strip("="))
cifar10_test = float(cifar10[2][-7:-2].strip("="))
cifar100_train = float(cifar100[1].strip(",valid")[-7:-2].strip("="))
cifar100_valid = float(cifar100[2].strip(",test")[-7:-2].strip("="))
cifar100_test = float(cifar100[3][-7:-2].strip("="))
imagenet16_train = float(imagenet16[1].strip(",valid")[-7:-2].strip("="))
imagenet16_valid = float(imagenet16[2].strip(",test")[-7:-2].strip("="))
imagenet16_test = float(imagenet16[3][-7:-2].strip("="))

return (
cifar10_train,
cifar10_test,
cifar100_train,
cifar100_valid,
cifar100_test,
imagenet16_train,
imagenet16_valid,
imagenet16_test,
)


def main():
torch.set_num_threads(3)
if not torch.cuda.is_available():
logging.info('no gpu device available')
sys.exit(1)

np.random.seed(args.seed)
torch.cuda.set_device(args.gpu)
setup_env()
# follow DrNAS settings.
torch.set_num_threads(3)
cudnn.benchmark = True
torch.manual_seed(args.seed)
cudnn.enabled = True
torch.cuda.manual_seed(args.seed)
logging.info('gpu device = %d' % args.gpu)
logging.info("args = %s", args)
if not 'debug' in args.save:
api = API('pth file path')
criterion = nn.CrossEntropyLoss()
criterion = criterion.cuda()

assert args.method in ['gdas', 'snas', 'dirichlet', 'darts'], "method not supported."

if args.method == 'gdas' or args.method == 'snas':

if not "debug" in cfg.OUT_DIR:
api = API("./data/NAS-Bench-201-v1_1-096897.pth")

criterion = build_loss_fun().cuda()

assert cfg.DRNAS.METHOD in [
"gdas",
"snas",
"dirichlet",
"darts",
], "method not supported."

if cfg.DRNAS.METHOD == "gdas" or cfg.DRNAS.METHOD == "snas":
[tau_min, tau_max] = cfg.DRNAS.TAU
# Create the decrease step for the gumbel softmax temperature
tau_step = (args.tau_min - args.tau_max) / args.epochs
tau_epoch = args.tau_max
tau_step = (tau_min - tau_max) / cfg.OPTIM.MAX_EPOCH
tau_epoch = tau_max

model = DrNAS_builder().cuda()

# if args.method == 'gdas':
# model = TinyNetworkGDAS(C=cfg.SPACE.CHANNEL, N=cfg.SPACE.LAYERS, max_nodes=cfg.SPACE.NODES, num_classes=cfg.SEARCH.NUM_CLASSES,
# criterion=criterion, search_space=NAS_BENCH_201)
# elif args.method == 'snas':
# model = TinyNetwork(C=cfg.SPACE.CHANNEL, N=cfg.SPACE.LAYERS, max_nodes=cfg.SPACE.NODES, num_classes=cfg.SEARCH.NUM_CLASSES,
# criterion=criterion, search_space=NAS_BENCH_201, k=args.k, species='gumbel',
# reg_type="l2", reg_scale=1e-3)
# elif args.method == 'dirichlet':
# model = TinyNetwork(C=cfg.SPACE.CHANNEL, N=cfg.SPACE.LAYERS, max_nodes=cfg.SPACE.NODES, num_classes=cfg.SEARCH.NUM_CLASSES,
# criterion=criterion, search_space=NAS_BENCH_201, k=args.k, species='dirichlet',
# reg_type=args.reg_type, reg_scale=args.reg_scale)
# elif args.method == 'darts':
# model = TinyNetwork(C=cfg.SPACE.CHANNEL, N=cfg.SPACE.LAYERS, max_nodes=cfg.SPACE.NODES, num_classes=cfg.SEARCH.NUM_CLASSES,
# criterion=criterion, search_space=NAS_BENCH_201, k=args.k, species='softmax',
# reg_type="l2", reg_scale=1e-3)
# model = model.cuda()
logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
logger.info("param size = %fMB", utils.count_parameters_in_MB(model))

optimizer = torch.optim.SGD(
model.get_weights(),
args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay)

if args.dataset == 'cifar10':
train_transform, valid_transform = utils._data_transforms_cifar10(args)
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
elif args.dataset == 'cifar100':
train_transform, valid_transform = utils._data_transforms_cifar100(args)
train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
elif args.dataset == 'svhn':
train_transform, valid_transform = utils._data_transforms_svhn(args)
train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
elif args.dataset == 'imagenet16-120':
import torchvision.transforms as transforms
from xnas.datasets.imagenet16 import ImageNet16
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
std = [x / 255 for x in [63.22, 61.26, 65.09]]
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
train_transform = transforms.Compose(lists)
train_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120)
assert len(train_data) == 151700

num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(args.train_portion * num_train))

train_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
pin_memory=True)

valid_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
pin_memory=True)
cfg.OPTIM.BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
)

train_loader, valid_loader = construct_loader(
cfg.SEARCH.DATASET,
cfg.SEARCH.SPLIT,
cfg.SEARCH.BATCH_SIZE,
cfg.SEARCH.DATAPATH,
num_workers=cfg.DATA_LOADER.NUM_WORKERS,
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, float(args.epochs), eta_min=args.learning_rate_min)
optimizer, float(cfg.OPTIM.MAX_EPOCH), eta_min=cfg.OPTIM.MIN_LR
)

architect = Architect(model, cfg)

architect = Architect(model, args)
train_meter = meters.TrainMeter(len(train_loader))
val_meter = meters.TestMeter(len(valid_loader))

for epoch in range(args.epochs):
# train_timer = Timer()
for current_epoch in range(cfg.OPTIM.MAX_EPOCH):
lr = scheduler.get_lr()[0]
logging.info('epoch %d lr %e', epoch, lr)
logger.info("epoch %d lr %e", current_epoch, lr)

genotype = model.genotype()
logging.info('genotype = %s', genotype)
logger.info("genotype = %s", genotype)
model.show_arch_parameters(logger)

# training
train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, epoch)
logging.info('train_acc %f', train_acc)
# train_timer.tic()
top1err = train_epoch(
train_loader,
valid_loader,
model,
architect,
criterion,
optimizer,
lr,
train_meter,
current_epoch,
)
logger.info("Top1 err:%f", top1err)
# train_timer.toc()
# print("epoch time:{}".format(train_timer.diff))

# validation
valid_acc, valid_obj = infer(valid_queue, model, criterion)
logging.info('valid_acc %f', valid_acc)
test_epoch(valid_loader, model, val_meter, current_epoch, writer)

if not 'debug' in args.save:
if not "debug" in cfg.OUT_DIR:
# nasbench201
result = api.query_by_arch(model.genotype())
logging.info('{:}'.format(result))
cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = distill(result)
logging.info('cifar10 train %f test %f', cifar10_train, cifar10_test)
logging.info('cifar100 train %f valid %f test %f', cifar100_train, cifar100_valid, cifar100_test)
logging.info('imagenet16 train %f valid %f test %f', imagenet16_train, imagenet16_valid, imagenet16_test)
logger.info("{:}".format(result))
(
cifar10_train,
cifar10_test,
cifar100_train,
cifar100_valid,
cifar100_test,
imagenet16_train,
imagenet16_valid,
imagenet16_test,
) = distill(result)
logger.info("cifar10 train %f test %f", cifar10_train, cifar10_test)
logger.info(
"cifar100 train %f valid %f test %f",
cifar100_train,
cifar100_valid,
cifar100_test,
)
logger.info(
"imagenet16 train %f valid %f test %f",
imagenet16_train,
imagenet16_valid,
imagenet16_test,
)

# tensorboard
writer.add_scalars('accuracy', {'train':train_acc,'valid':valid_acc}, epoch)
writer.add_scalars('loss', {'train':train_obj,'valid':valid_obj}, epoch)
writer.add_scalars('nasbench201/cifar10', {'train':cifar10_train,'test':cifar10_test}, epoch)
writer.add_scalars('nasbench201/cifar100', {'train':cifar100_train,'valid':cifar100_valid, 'test':cifar100_test}, epoch)
writer.add_scalars('nasbench201/imagenet16', {'train':imagenet16_train,'valid':imagenet16_valid, 'test':imagenet16_test}, epoch)

utils.save_checkpoint({
'epoch': epoch + 1,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'alpha': model.arch_parameters()
}, False, args.save)
writer.add_scalars(
"nasbench201/cifar10",
{"train": cifar10_train, "test": cifar10_test},
current_epoch,
)
writer.add_scalars(
"nasbench201/cifar100",
{
"train": cifar100_train,
"valid": cifar100_valid,
"test": cifar100_test,
},
current_epoch,
)
writer.add_scalars(
"nasbench201/imagenet16",
{
"train": imagenet16_train,
"valid": imagenet16_valid,
"test": imagenet16_test,
},
current_epoch,
)

utils.save_checkpoint(
{
"epoch": current_epoch + 1,
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
"alpha": model.arch_parameters(),
},
False,
cfg.OUT_DIR,
)

scheduler.step()
if args.method == 'gdas' or args.method == 'snas':
if cfg.DRNAS.METHOD == "gdas" or cfg.DRNAS.METHOD == "snas":
# Decrease the temperature for the gumbel softmax linearly
tau_epoch += tau_step
logging.info('tau %f', tau_epoch)
logger.info("tau %f", tau_epoch)
model.set_tau(tau_epoch)

writer.close()
# print("avg epoch time:{}".format(train_timer.average_time))
# train_timer.reset()

writer.close()

def train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, epoch):
objs = utils.AvgrageMeter()
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()

for step, (input, target) in enumerate(train_queue):
def train_epoch(
train_loader,
valid_loader,
model,
architect,
criterion,
optimizer,
lr,
train_meter,
cur_epoch,
):
train_meter.iter_tic()
cur_step = cur_epoch * len(train_loader)
writer.add_scalar("train/lr", lr, cur_step)

valid_loader_iter = iter(valid_loader)

for cur_iter, (trn_X, trn_y) in enumerate(train_loader):
model.train()
n = input.size(0)

input = input.cuda()
target = target.cuda(non_blocking=True)
try:
(val_X, val_y) = next(valid_loader_iter)
except StopIteration:
valid_loader_iter = iter(valid_loader)
(val_X, val_y) = next(valid_loader_iter)
# Transfer the data to the current GPU device
trn_X, trn_y = trn_X.cuda(), trn_y.cuda(non_blocking=True)
val_X, val_y = val_X.cuda(), val_y.cuda(non_blocking=True)

# get a random minibatch from the search queue with replacement
input_search, target_search = next(iter(valid_queue))
input_search = input_search.cuda()
target_search = target_search.cuda(non_blocking=True)
# if epoch >= 15:
architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
architect.step(
trn_X, trn_y, val_X, val_y, lr, optimizer, unrolled=cfg.DRNAS.UNROLLED
)
optimizer.zero_grad()
architect.optimizer.zero_grad()

logits = model(input)
loss = criterion(logits, target)
logits = model(trn_X)
loss = criterion(logits, trn_y)

loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
nn.utils.clip_grad_norm_(model.parameters(), cfg.OPTIM.GRAD_CLIP)
optimizer.step()
optimizer.zero_grad()
architect.optimizer.zero_grad()

prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
objs.update(loss.data, n)
top1.update(prec1.data, n)
top5.update(prec5.data, n)

if step % args.report_freq == 0:
logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
if 'debug' in args.save:
break

return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
objs = utils.AvgrageMeter()
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()
model.eval()

with torch.no_grad():
for step, (input, target) in enumerate(valid_queue):
input = input.cuda()
target = target.cuda(non_blocking=True)

logits = model(input)
loss = criterion(logits, target)

prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
n = input.size(0)
objs.update(loss.data, n)
top1.update(prec1.data, n)
top5.update(prec5.data, n)

if step % args.report_freq == 0:
logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
if 'debug' in args.save:
break
return top1.avg, objs.avg


if __name__ == '__main__':
top1_err, top5_err = meters.topk_errors(logits, trn_y, [1, 5])
loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
train_meter.iter_toc()

# Update and log stats
# TODO: multiplying by NUM_GPUS is disabled until data-parallel training is applied
# mb_size = trn_X.size(0) * cfg.NUM_GPUS
mb_size = trn_X.size(0)
train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
train_meter.log_iter_stats(cur_epoch, cur_iter)
train_meter.iter_tic()
# write to tensorboard
writer.add_scalar("train/loss", loss, cur_step)
writer.add_scalar("train/top1_error", top1_err, cur_step)
writer.add_scalar("train/top5_error", top5_err, cur_step)
cur_step += 1
# Log epoch stats
train_meter.log_epoch_stats(cur_epoch)
train_meter.reset()
return top1_err


if __name__ == "__main__":
main()
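
For `gdas` and `snas`, the schedule above decays the Gumbel-softmax temperature linearly from `tau_max` to `tau_min` over `cfg.OPTIM.MAX_EPOCH` epochs. A quick numeric check with the values from the configs in this PR (`TAU: [1, 10]`, `MAX_EPOCH: 100`):

```python
# Linear tau schedule check; mirrors tau_step / tau_epoch in the code above.
tau_min, tau_max = 1.0, 10.0                 # cfg.DRNAS.TAU
max_epoch = 100                              # cfg.OPTIM.MAX_EPOCH

tau_step = (tau_min - tau_max) / max_epoch   # -0.09 per epoch
tau = tau_max
for _ in range(max_epoch):
    tau += tau_step                          # model.set_tau(tau) is called here in the script
print(f"final tau = {tau:.4f}")              # 1.0000
```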

search/DrNAS/nb201space_progressive.py (+241, -267)

@@ -1,337 +1,311 @@
import os
import sys
import time
import glob
import numpy as np
import torch
import logging
import argparse
import torch.nn as nn
import torch.utils
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn

import xnas.core.logging as logging
import xnas.core.config as config
import xnas.core.meters as meters
import xnas.search_space.DrNAS.utils as utils
from xnas.search_space.DrNAS.nb201space.cnn import TinyNetwork
from xnas.search_space.DrNAS.nb201space.ops import NAS_BENCH_201
from xnas.core.builders import build_loss_fun, DrNAS_builder
from xnas.core.config import cfg
from xnas.core.timer import Timer
from xnas.core.trainer import setup_env, test_epoch
from xnas.datasets.loader import construct_loader
from xnas.search_algorithm.DrNAS import Architect


from torch.utils.tensorboard import SummaryWriter
from nas_201_api import NASBench201API as API

from xnas.core.config import cfg

# Load config and check
config.load_cfg_fom_args()
config.assert_and_infer_cfg()
cfg.freeze()
# Tensorboard supplement
writer = SummaryWriter(log_dir=os.path.join(cfg.OUT_DIR, "tb"))

parser = argparse.ArgumentParser("sota")
parser.add_argument('--data', type=str, default='datapath', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
parser.add_argument('--method', type=str, default='dirichlet', help='choose nas method')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
parser.add_argument('--save', type=str, default='exp', help='experiment name')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss')
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--tau_max', type=float, default=10, help='Max temperature (tau) for the gumbel softmax.')
parser.add_argument('--tau_min', type=float, default=1, help='Min temperature (tau) for the gumbel softmax.')
parser.add_argument('--k', type=int, default=4, help='init partial channel parameter')
#### regularization
parser.add_argument('--reg_type', type=str, default='l2', choices=[
'l2', 'kl'], help='regularization type, kl is implemented for dirichlet only')
parser.add_argument('--reg_scale', type=float, default=1e-3,
help='scaling factor of the regularization term, default value is proper for l2, for kl you might adjust reg_scale to match l2')
args = parser.parse_args()

args.save = '../experiments/nasbench201/{}-search-progressive-{}-{}-{}'.format(
args.method, args.save, time.strftime("%Y%m%d-%H%M%S"), args.seed)
if not args.dataset == 'cifar10':
args.save += '-' + args.dataset
if args.unrolled:
args.save += '-unrolled'
if not args.weight_decay == 3e-4:
args.save += '-weight_l2-' + str(args.weight_decay)
args.save += '-pc-' + str(args.k)

utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')


if args.dataset == 'cifar100':
n_classes = 100
elif args.dataset == 'imagenet16-120':
n_classes = 120
else:
n_classes = 10
logger = logging.get_logger(__name__)


def distill(result):
result = result.split('\n')
cifar10 = result[5].replace(' ', '').split(':')
cifar100 = result[7].replace(' ', '').split(':')
imagenet16 = result[9].replace(' ', '').split(':')

cifar10_train = float(cifar10[1].strip(',test')[-7:-2].strip('='))
cifar10_test = float(cifar10[2][-7:-2].strip('='))
cifar100_train = float(cifar100[1].strip(',valid')[-7:-2].strip('='))
cifar100_valid = float(cifar100[2].strip(',test')[-7:-2].strip('='))
cifar100_test = float(cifar100[3][-7:-2].strip('='))
imagenet16_train = float(imagenet16[1].strip(',valid')[-7:-2].strip('='))
imagenet16_valid = float(imagenet16[2].strip(',test')[-7:-2].strip('='))
imagenet16_test = float(imagenet16[3][-7:-2].strip('='))

return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test
result = result.split("\n")
cifar10 = result[5].replace(" ", "").split(":")
cifar100 = result[7].replace(" ", "").split(":")
imagenet16 = result[9].replace(" ", "").split(":")

cifar10_train = float(cifar10[1].strip(",test")[-7:-2].strip("="))
cifar10_test = float(cifar10[2][-7:-2].strip("="))
cifar100_train = float(cifar100[1].strip(",valid")[-7:-2].strip("="))
cifar100_valid = float(cifar100[2].strip(",test")[-7:-2].strip("="))
cifar100_test = float(cifar100[3][-7:-2].strip("="))
imagenet16_train = float(imagenet16[1].strip(",valid")[-7:-2].strip("="))
imagenet16_valid = float(imagenet16[2].strip(",test")[-7:-2].strip("="))
imagenet16_test = float(imagenet16[3][-7:-2].strip("="))

return (
cifar10_train,
cifar10_test,
cifar100_train,
cifar100_valid,
cifar100_test,
imagenet16_train,
imagenet16_valid,
imagenet16_test,
)


def main():
torch.set_num_threads(3)
if not torch.cuda.is_available():
logging.info('no gpu device available')
sys.exit(1)

np.random.seed(args.seed)
torch.cuda.set_device(args.gpu)
setup_env()
# follow DrNAS settings.
torch.set_num_threads(3)
cudnn.benchmark = True
torch.manual_seed(args.seed)
cudnn.enabled = True
torch.cuda.manual_seed(args.seed)
logging.info('gpu device = %d' % args.gpu)
logging.info("args = %s", args)
if not 'debug' in args.save:
api = API('pth file path')
criterion = nn.CrossEntropyLoss()
criterion = criterion.cuda()

assert args.method in ['snas', 'dirichlet', 'darts'], "method not supported."

if args.method == 'snas':

if not "debug" in cfg.OUT_DIR:
api = API("./data/NAS-Bench-201-v1_1-096897.pth")

criterion = build_loss_fun().cuda()

assert cfg.DRNAS.METHOD in ["snas", "dirichlet", "darts"], "method not supported."

if cfg.DRNAS.METHOD == "snas":
# Create the decrease step for the gumbel softmax temperature
args.epochs = 100
tau_step = (args.tau_min - args.tau_max) / args.epochs
tau_epoch = args.tau_max
model = TinyNetwork(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes,
criterion=criterion, search_space=NAS_BENCH_201, k=args.k, species='gumbel',
reg_type="l2", reg_scale=1e-3)
elif args.method == 'dirichlet':
model = TinyNetwork(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes,
criterion=criterion, search_space=NAS_BENCH_201, k=args.k, species='dirichlet',
reg_type=args.reg_type, reg_scale=args.reg_scale)
elif args.method == 'darts':
model = TinyNetwork(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes,
criterion=criterion, search_space=NAS_BENCH_201, k=args.k, species='softmax',
reg_type="l2", reg_scale=1e-3)
model = model.cuda()
logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
# cfg.OPTIM.MAX_EPOCH = 100
[tau_min, tau_max] = cfg.DRNAS.TAU
# Create the decrease step for the gumbel softmax temperature
tau_step = (tau_min - tau_max) / cfg.OPTIM.MAX_EPOCH
tau_epoch = tau_max

model = DrNAS_builder().cuda()

logger.info("param size = %fMB", utils.count_parameters_in_MB(model))

optimizer = torch.optim.SGD(
model.get_weights(),
args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay)

if args.dataset == 'cifar10':
train_transform, valid_transform = utils._data_transforms_cifar10(args)
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
elif args.dataset == 'cifar100':
train_transform, valid_transform = utils._data_transforms_cifar100(args)
train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
elif args.dataset == 'svhn':
train_transform, valid_transform = utils._data_transforms_svhn(args)
train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
elif args.dataset == 'imagenet16-120':
import torchvision.transforms as transforms
from xnas.datasets.imagenet16 import ImageNet16
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
std = [x / 255 for x in [63.22, 61.26, 65.09]]
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
train_transform = transforms.Compose(lists)
train_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120)
assert len(train_data) == 151700

num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(args.train_portion * num_train))

train_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
pin_memory=True)

valid_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
pin_memory=True)

architect = Architect(model, args)
cfg.OPTIM.BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
)

train_loader, valid_loader = construct_loader(
cfg.SEARCH.DATASET,
cfg.SEARCH.SPLIT,
cfg.SEARCH.BATCH_SIZE,
cfg.SEARCH.DATAPATH,
num_workers=cfg.DATA_LOADER.NUM_WORKERS,
)

architect = Architect(model, cfg)

# configure progressive parameter
epoch = 0
ks = [4, 2]
num_keeps = [5, 3]
train_epochs = [2, 2] if 'debug' in args.save else [50, 50]
train_epochs = [2, 2] if "debug" in cfg.OUT_DIR else [50, 50]

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, float(sum(train_epochs)), eta_min=args.learning_rate_min)
optimizer, float(sum(train_epochs)), eta_min=cfg.OPTIM.MIN_LR
)

train_meter = meters.TrainMeter(len(train_loader))
val_meter = meters.TestMeter(len(valid_loader))

# train_timer = Timer()
for i, current_epochs in enumerate(train_epochs):
for e in range(current_epochs):
lr = scheduler.get_lr()[0]
logging.info('epoch %d lr %e', epoch, lr)
logger.info("epoch %d lr %e", epoch, lr)
genotype = model.genotype()
logging.info('genotype = %s', genotype)
logger.info("genotype = %s", genotype)
model.show_arch_parameters(logger)

# training
train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, e)
logging.info('train_acc %f', train_acc)
# train_timer.tic()
top1err = train_epoch(
train_loader,
valid_loader,
model,
architect,
criterion,
optimizer,
lr,
train_meter,
e,
)
logger.info("Top1 err:%f", top1err)
# train_timer.toc()
# print("epoch time:{}".format(train_timer.diff))

# validation
valid_acc, valid_obj = infer(valid_queue, model, criterion)
logging.info('valid_acc %f', valid_acc)
test_epoch(valid_loader, model, val_meter, epoch, writer)

if not 'debug' in args.save:
if not "debug" in cfg.OUT_DIR:
# nasbench201
result = api.query_by_arch(model.genotype())
logging.info('{:}'.format(result))
cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = distill(result)
logging.info('cifar10 train %f test %f', cifar10_train, cifar10_test)
logging.info('cifar100 train %f valid %f test %f', cifar100_train, cifar100_valid, cifar100_test)
logging.info('imagenet16 train %f valid %f test %f', imagenet16_train, imagenet16_valid, imagenet16_test)
logger.info("{:}".format(result))
(
cifar10_train,
cifar10_test,
cifar100_train,
cifar100_valid,
cifar100_test,
imagenet16_train,
imagenet16_valid,
imagenet16_test,
) = distill(result)
logger.info("cifar10 train %f test %f", cifar10_train, cifar10_test)
logger.info(
"cifar100 train %f valid %f test %f",
cifar100_train,
cifar100_valid,
cifar100_test,
)
logger.info(
"imagenet16 train %f valid %f test %f",
imagenet16_train,
imagenet16_valid,
imagenet16_test,
)

# tensorboard
writer.add_scalars('accuracy', {'train':train_acc,'valid':valid_acc}, epoch)
writer.add_scalars('loss', {'train':train_obj,'valid':valid_obj}, epoch)
writer.add_scalars('nasbench201/cifar10', {'train':cifar10_train,'test':cifar10_test}, epoch)
writer.add_scalars('nasbench201/cifar100', {'train':cifar100_train,'valid':cifar100_valid, 'test':cifar100_test}, epoch)
writer.add_scalars('nasbench201/imagenet16', {'train':imagenet16_train,'valid':imagenet16_valid, 'test':imagenet16_test}, epoch)

utils.save_checkpoint({
'epoch': epoch + 1,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'alpha': model.arch_parameters()
}, False, args.save)
writer.add_scalars(
"nasbench201/cifar10",
{"train": cifar10_train, "test": cifar10_test},
epoch,
)
writer.add_scalars(
"nasbench201/cifar100",
{
"train": cifar100_train,
"valid": cifar100_valid,
"test": cifar100_test,
},
epoch,
)
writer.add_scalars(
"nasbench201/imagenet16",
{
"train": imagenet16_train,
"valid": imagenet16_valid,
"test": imagenet16_test,
},
epoch,
)

utils.save_checkpoint(
{
"epoch": epoch + 1,
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
"alpha": model.arch_parameters(),
},
False,
cfg.OUT_DIR,
)

epoch += 1
scheduler.step()
if args.method == 'snas':
if cfg.DRNAS.METHOD == "snas":
# Decrease the temperature for the gumbel softmax linearly
tau_epoch += tau_step
logging.info('tau %f', tau_epoch)
logger.info("tau %f", tau_epoch)
model.set_tau(tau_epoch)

if not i == len(train_epochs) - 1:
model.pruning(num_keeps[i+1])
model.pruning(num_keeps[i + 1])
# architect.pruning([model._mask])
model.wider(ks[i+1])
optimizer = utils.configure_optimizer(optimizer, torch.optim.SGD(
model.get_weights(),
args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay))
scheduler = utils.configure_scheduler(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, float(sum(train_epochs)), eta_min=args.learning_rate_min))
logging.info('pruning finish, %d ops left per edge', num_keeps[i+1])
logging.info('network wider finish, current pc parameter %d', ks[i+1])
model.wider(ks[i + 1])
optimizer = utils.configure_optimizer(
optimizer,
torch.optim.SGD(
model.get_weights(),
cfg.OPTIM.BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
),
)
scheduler = utils.configure_scheduler(
scheduler,
torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, float(sum(train_epochs)), eta_min=cfg.OPTIM.MIN_LR
),
)
logger.info("pruning finish, %d ops left per edge", num_keeps[i + 1])
logger.info("network wider finish, current pc parameter %d", ks[i + 1])

genotype = model.genotype()
logging.info('genotype = %s', genotype)
logger.info("genotype = %s", genotype)
model.show_arch_parameters(logger)
writer.close()


def train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, epoch):
objs = utils.AvgrageMeter()
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()

for step, (input, target) in enumerate(train_queue):
def train_epoch(
train_loader,
valid_loader,
model,
architect,
criterion,
optimizer,
lr,
train_meter,
cur_epoch,
):
train_meter.iter_tic()
cur_step = cur_epoch * len(train_loader)
writer.add_scalar("train/lr", lr, cur_step)

valid_loader_iter = iter(valid_loader)

for cur_iter, (trn_X, trn_y) in enumerate(train_loader):
model.train()
n = input.size(0)

input = input.cuda()
target = target.cuda(non_blocking=True)

# get a random minibatch from the search queue with replacement
input_search, target_search = next(iter(valid_queue))
input_search = input_search.cuda()
target_search = target_search.cuda(non_blocking=True)
if epoch >= 10:
architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
try:
(val_X, val_y) = next(valid_loader_iter)
except StopIteration:
valid_loader_iter = iter(valid_loader)
(val_X, val_y) = next(valid_loader_iter)
# Transfer the data to the current GPU device
trn_X, trn_y = trn_X.cuda(), trn_y.cuda(non_blocking=True)
val_X, val_y = val_X.cuda(), val_y.cuda(non_blocking=True)

if cur_epoch >= 10:
architect.step(
trn_X, trn_y, val_X, val_y, lr, optimizer, unrolled=cfg.DRNAS.UNROLLED
)
optimizer.zero_grad()
architect.optimizer.zero_grad()

logits = model(input)
loss = criterion(logits, target)
logits = model(trn_X)
loss = criterion(logits, trn_y)

loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
nn.utils.clip_grad_norm_(model.parameters(), cfg.OPTIM.GRAD_CLIP)
optimizer.step()
optimizer.zero_grad()
architect.optimizer.zero_grad()

prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
objs.update(loss.data, n)
top1.update(prec1.data, n)
top5.update(prec5.data, n)

if step % args.report_freq == 0:
logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
if 'debug' in args.save:
break

return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
objs = utils.AvgrageMeter()
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()
model.eval()

with torch.no_grad():
for step, (input, target) in enumerate(valid_queue):
input = input.cuda()
target = target.cuda(non_blocking=True)

logits = model(input)
loss = criterion(logits, target)

prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
n = input.size(0)
objs.update(loss.data, n)
top1.update(prec1.data, n)
top5.update(prec5.data, n)

if step % args.report_freq == 0:
logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
if 'debug' in args.save:
break

return top1.avg, objs.avg


if __name__ == '__main__':
top1_err, top5_err = meters.topk_errors(logits, trn_y, [1, 5])
loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
train_meter.iter_toc()

# Update and log stats
# TODO: multiplying by NUM_GPUS is disabled until data-parallel training is applied
# mb_size = trn_X.size(0) * cfg.NUM_GPUS
mb_size = trn_X.size(0)
train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
train_meter.log_iter_stats(cur_epoch, cur_iter)
train_meter.iter_tic()
# write to tensorboard
writer.add_scalar("train/loss", loss, cur_step)
writer.add_scalar("train/top1_error", top1_err, cur_step)
writer.add_scalar("train/top5_error", top5_err, cur_step)
cur_step += 1
# Log epoch stats
train_meter.log_epoch_stats(cur_epoch)
train_meter.reset()
return top1_err


if __name__ == "__main__":
main()
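
Both refactored nb201 scripts now report top-1/top-5 errors through `meters.topk_errors(logits, trn_y, [1, 5])` instead of the old `utils.accuracy` meters. The sketch below shows what such a top-k error computation typically looks like; it is an assumption about the behaviour, not the `xnas.core.meters` source.

```python
# Hedged sketch of a top-k error computation in the spirit of meters.topk_errors().
import torch


def topk_errors_sketch(logits, labels, ks=(1, 5)):
    max_k = max(ks)
    # indices of the max_k highest-scoring classes per sample, shape [B, max_k]
    _, pred = logits.topk(max_k, dim=1, largest=True, sorted=True)
    correct = pred.eq(labels.view(-1, 1).expand_as(pred))
    errors = []
    for k in ks:
        num_correct = correct[:, :k].reshape(-1).float().sum()
        errors.append((1.0 - num_correct / labels.size(0)) * 100.0)
    return errors  # e.g. [top1_err, top5_err] as percentages
```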

search/PCDARTS_search.py (+24, -19)

@@ -15,7 +15,8 @@ from xnas.core.timer import Timer
from xnas.core.builders import build_space, build_loss_fun, lr_scheduler_builder
from xnas.core.config import cfg
from xnas.core.trainer import setup_env, test_epoch
from xnas.datasets.cifar10 import data_transforms_cifar10
from xnas.datasets.loader import construct_loader
# from xnas.datasets.old.cifar10 import data_transforms_cifar10
from xnas.search_algorithm.PCDARTS import *
from DARTS_search import darts_load_checkpoint, darts_save_checkpoint

@@ -45,25 +46,29 @@ def pcdarts_train_model():
architect = Architect(
pcdarts_controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY)

# # Load dataset
# train_transform, valid_transform = data_transforms_cifar10(cutout_length=0)

# train_data = dset.CIFAR10(
# root=cfg.SEARCH.DATASET, train=True, download=True, transform=train_transform)

# num_train = len(train_data)
# indices = list(range(num_train))
# split = int(np.floor(cfg.SEARCH.SPLIT[0] * num_train))

# train_ = torch.utils.data.DataLoader(
# train_data, batch_size=cfg.SEARCH.BATCH_SIZE,
# sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
# pin_memory=True, num_workers=2)
# val_ = torch.utils.data.DataLoader(
# train_data, batch_size=cfg.SEARCH.BATCH_SIZE,
# sampler=torch.utils.data.sampler.SubsetRandomSampler(
# indices[split:num_train]),
# pin_memory=True, num_workers=2)

# Load dataset
train_transform, valid_transform = data_transforms_cifar10(cutout_length=0)

train_data = dset.CIFAR10(
root=cfg.SEARCH.DATASET, train=True, download=True, transform=train_transform)

num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(cfg.SEARCH.SPLIT[0] * num_train))

train_ = torch.utils.data.DataLoader(
train_data, batch_size=cfg.SEARCH.BATCH_SIZE,
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
pin_memory=True, num_workers=2)
val_ = torch.utils.data.DataLoader(
train_data, batch_size=cfg.SEARCH.BATCH_SIZE,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
indices[split:num_train]),
pin_memory=True, num_workers=2)
[train_, val_] = construct_loader(
cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT, cfg.SEARCH.BATCH_SIZE, cfg.SEARCH.DATAPATH)

# weights optimizer
w_optim = torch.optim.SGD(pcdarts_controller.weights(),
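
The manual SubsetRandomSampler split is now wrapped by construct_loader, which returns one DataLoader per entry in cfg.SEARCH.SPLIT. A minimal sketch of the intended call pattern, assuming the loaders come back in split order (train first, then validation) and are consumed by the usual DARTS-style bi-level loop:

# A sketch of the new call site; values mirror the YAML configs above.
from xnas.datasets.loader import construct_loader

train_, val_ = construct_loader(
    name="cifar10",
    split=[0.5, 0.5],        # one DataLoader per split fraction
    batch_size=64,
    datapath="data/cifar10",
)

for (trn_X, trn_y), (val_X, val_y) in zip(train_, val_):
    # architecture step on the validation half, weight step on the training half
    break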


+ 44
- 39
tools/test_func/data_loader_test.py View File

@@ -1,49 +1,54 @@
from xnas.datasets.imagenet import XNAS_ImageFolder
from xnas.datasets.cifar10 import XNAS_Cifar10
# from xnas.datasets.old.cifar10 import XNAS_Cifar10

"""
TODO:
The old dataloader_test.py is deprecated now that the new loader is in place; a minimal replacement sketch follows this diff.
"""

def image_folder_test():
for backend in ['torch', 'dali_cpu', 'dali_gpu', 'custom']:
print('Testing the dataloader with backend: {}'.format(backend))
dataset = XNAS_ImageFolder('/gdata/Caltech256/256_ObjectCategories',
[0.8, 0.2],
backend=backend,
dataset_name='custom')
[train_, val_] = dataset.generate_data_loader()
# def image_folder_test():
# for backend in ['torch', 'dali_cpu', 'dali_gpu']:
# print('Testing the dataloader with backend: {}'.format(backend))
# dataset = XNAS_ImageFolder('/gdata/Caltech256/256_ObjectCategories',
# [0.8, 0.2],
# backend=backend,
# dataset_name='custom')
# [train_, val_] = dataset.generate_data_loader()

for i, (inputs, labels) in enumerate(train_):
inputs = inputs.cuda()
labels = labels.cuda()
print(inputs)
print(labels)
break
for i, (inputs, labels) in enumerate(val_):
inputs = inputs.cuda()
labels = labels.cuda()
print(inputs)
print(labels)
break
print('testing passed')
# for i, (inputs, labels) in enumerate(train_):
# inputs = inputs.cuda()
# labels = labels.cuda()
# print(inputs)
# print(labels)
# break
# for i, (inputs, labels) in enumerate(val_):
# inputs = inputs.cuda()
# labels = labels.cuda()
# print(inputs)
# print(labels)
# break
# print('testing passed')


def cifar10_test():
from xnas.core.config import cfg
[train_, val_] = XNAS_Cifar10('/gdata/cifar10/cifar-10-batches-py', [0.8, 0.2])
for i, (inputs, labels) in enumerate(train_):
inputs = inputs.cuda()
labels = labels.cuda()
print(inputs)
print(labels)
break
for i, (inputs, labels) in enumerate(val_):
inputs = inputs.cuda()
labels = labels.cuda()
print(inputs)
print(labels)
break
print('testing passed')
# def cifar10_test():
# from xnas.core.config import cfg
# [train_, val_] = XNAS_Cifar10('/gdata/cifar10/cifar-10-batches-py', [0.8, 0.2])
# for i, (inputs, labels) in enumerate(train_):
# inputs = inputs.cuda()
# labels = labels.cuda()
# print(inputs)
# print(labels)
# break
# for i, (inputs, labels) in enumerate(val_):
# inputs = inputs.cuda()
# labels = labels.cuda()
# print(inputs)
# print(labels)
# break
# print('testing passed')


if __name__ == "__main__":
cifar10_test()
pass
# cifar10_test()
# image_folder_test()
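
A minimal replacement for the deprecated tests above, built on the new construct_loader API (a sketch only; the data path is a placeholder):

from xnas.datasets.loader import construct_loader

def loader_smoke_test():
    # Illustrative local path; point it at a CIFAR-10 copy (or let torchvision download it).
    train_, val_ = construct_loader("cifar10", [0.8, 0.2], batch_size=32, datapath="./data/cifar10")
    for loader in (train_, val_):
        inputs, labels = next(iter(loader))
        assert inputs.dim() == 4 and labels.size(0) == inputs.size(0)
    print("loader smoke test passed")

if __name__ == "__main__":
    loader_smoke_test()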

+ 6
- 5
xnas/core/builders.py View File

@@ -91,17 +91,18 @@ def register_loss_fun(name, ctor):


def DrNAS_builder():
criterion = build_loss_fun()
if cfg.SPACE.NAME == 'darts':
return _DrNASCNN_DARTSspace()
return _DrNASCNN_DARTSspace(criterion)
elif cfg.SPACE.NAME == 'nasbench201':
if cfg.DRNAS.METHOD == 'gdas':
return _DrNASCNN_GDAS_nb201space()
return _DrNASCNN_GDAS_nb201space(criterion)
elif cfg.DRNAS.METHOD == 'snas':
return _DrNASCNN_nb201space('gumbel')
return _DrNASCNN_nb201space('gumbel', criterion)
elif cfg.DRNAS.METHOD == 'dirichlet':
return _DrNASCNN_nb201space('dirichlet')
return _DrNASCNN_nb201space('dirichlet', criterion)
elif cfg.DRNAS.METHOD == 'darts':
return _DrNASCNN_nb201space('softmax')
return _DrNASCNN_nb201space('softmax', criterion)
else:
raise NotImplementedError



+ 3
- 0
xnas/core/config.py View File

@@ -434,6 +434,9 @@ _C.DRNAS.REG_SCALE = 1e-3
# method for nb201 space
_C.DRNAS.METHOD = 'dirichlet'

# temperature (tau) range for gumbel softmax
_C.DRNAS.TAU = [1, 10]



def dump_cfg():
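
DRNAS.TAU stores the two endpoints of the Gumbel-softmax temperature used by the GDAS/SNAS methods. A minimal sketch of the linear annealing this range is typically fed into (the variable names and the [min, max] ordering are assumptions, not part of this diff):

tau_min, tau_max = 1.0, 10.0      # i.e. cfg.DRNAS.TAU == [1, 10]
max_epoch = 100                   # cfg.OPTIM.MAX_EPOCH in the configs above

def tau_at(epoch):
    """Anneal tau linearly from tau_max down to tau_min over the search."""
    return tau_max - (tau_max - tau_min) * epoch / (max_epoch - 1)

print([round(tau_at(e), 2) for e in (0, 50, 99)])   # [10.0, 5.45, 1.0]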


+ 0
- 160
xnas/datasets/cifar10.py View File

@@ -1,160 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""CIFAR10 dataset."""

import os
import pickle

import numpy as np
from numpy.core.defchararray import index
from numpy.lib import index_tricks
import xnas.core.logging as logging
import xnas.datasets.transforms as transforms
import torch.utils.data
from xnas.core.config import cfg

import torchvision.transforms as torch_trans


logger = logging.get_logger(__name__)

# Per-channel mean and SD values in BGR order
_MEAN = [125.3, 123.0, 113.9]
_SD = [63.0, 62.1, 66.7]

# TODO: DALI backend support

def XNAS_Cifar10(data_path, split, backend='custom', batch_size=256, num_workers=4):
"""
XNAS CIFAR-10: generate data loaders from the CIFAR-10 training set according to the given split and backend.
Distributed loading is not supported yet.
"""
if backend == 'custom':
train_data = Cifar10(data_path, 'train')
num_train = len(train_data)
indices = list(range(num_train))
# Shuffle data
np.random.shuffle(indices)
data_loaders = []
pre_partition = 0.
pre_index = 0
for _split in split:
current_partition = pre_partition + _split
current_index = int(num_train * current_partition)
current_indices = indices[pre_index: current_index]
assert not len(current_indices) == 0, "Length of indices is zero!"
_sampler = torch.utils.data.sampler.SubsetRandomSampler(
current_indices)
_data_loader = torch.utils.data.DataLoader(train_data,
batch_size=batch_size,
sampler=_sampler,
num_workers=num_workers,
pin_memory=True
)
data_loaders.append(_data_loader)
pre_partition = current_partition
pre_index = current_index
return data_loaders
else:
raise NotImplementedError


class Cifar10(torch.utils.data.Dataset):
"""CIFAR-10 dataset."""

def __init__(self, data_path, split):
assert os.path.exists(
data_path), "Data path '{}' not found".format(data_path)
splits = ["train", "test"]
assert split in splits, "Split '{}' not supported for cifar".format(
split)
logger.info("Constructing CIFAR-10 {}...".format(split))
self._data_path, self._split = data_path, split
self._inputs, self._labels = self._load_data()

def _load_data(self):
"""Loads data into memory."""
logger.info("{} data path: {}".format(self._split, self._data_path))
# Compute data batch names
if self._split == "train":
batch_names = ["data_batch_{}".format(i) for i in range(1, 6)]
else:
batch_names = ["test_batch"]
# Load data batches
inputs, labels = [], []
for batch_name in batch_names:
batch_path = os.path.join(self._data_path, batch_name)
with open(batch_path, "rb") as f:
data = pickle.load(f, encoding="bytes")
inputs.append(data[b"data"])
labels += data[b"labels"]
# Combine and reshape the inputs
inputs = np.vstack(inputs).astype(np.float32)
inputs = inputs.reshape(
(-1, 3, cfg.SEARCH.IM_SIZE, cfg.SEARCH.IM_SIZE))
return inputs, labels

def _prepare_im(self, im):
"""Prepares the image for network input."""
im = transforms.color_norm(im, _MEAN, _SD)
if self._split == "train":
im = transforms.horizontal_flip(im=im, p=0.5)
im = transforms.random_crop(
im=im, size=cfg.SEARCH.IM_SIZE, pad_size=4)
return im

def __getitem__(self, index):
im, label = self._inputs[index, ...].copy(), self._labels[index]
im = self._prepare_im(im)
return im, label

def __len__(self):
return self._inputs.shape[0]


def data_transforms_cifar10(cutout_length):
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

train_transform = torch_trans.Compose([
torch_trans.RandomCrop(32, padding=4),
torch_trans.RandomHorizontalFlip(),
torch_trans.ToTensor(),
torch_trans.Normalize(CIFAR_MEAN, CIFAR_STD),
])

valid_transform = torch_trans.Compose([
torch_trans.ToTensor(),
torch_trans.Normalize(CIFAR_MEAN, CIFAR_STD),
])

if cutout_length > 0:
train_transform.transforms.append(Cutout(cutout_length))

return train_transform, valid_transform


class Cutout(object):
def __init__(self, length):
self.length = length

def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)

y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)

mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask

return img

+ 3
- 2
xnas/datasets/imagenet.py View File

@@ -112,6 +112,7 @@ class XNAS_ImageFolder():
self._imdb.append({"im_path": im_path, "class": cont_id})
print("Number of images: {}".format(len(self._imdb)))
print("Number of classes: {}".format(len(self._class_ids)))
return self._imdb

def generate_data_loader(self):
indices = list(range(len(self._imdb)))
@@ -254,8 +255,8 @@ class ImageList_torch(torch.utils.data.Dataset):
# Scale -> center crop
transformer.append(torch_transforms.Resize(self.min_crop_size))
transformer.append(torch_transforms.CenterCrop(self.crop_size))
transformer.append(torch_transforms.ToTensor())
transformer.append(torch_transforms.Normalize(
transformer.append(torch_transforms.ToTensor())
transformer.append(torch_transforms.Normalize(
mean=self._bgr_normalized_mean, std=self._bgr_normalized_std))
self.transform = torch_transforms.Compose(transformer)



+ 9
- 1
xnas/datasets/imagenet16.py View File

@@ -1,10 +1,18 @@
import os, sys, hashlib, torch
import os, sys, hashlib
import numpy as np
from PIL import Image
import torch.utils.data as data
import pickle


# def ImageNet16Loader():
# mean = [x / 255 for x in [122.68, 116.66, 104.01]]
# std = [x / 255 for x in [63.22, 61.26, 65.09]]
# lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
# train_transform = transforms.Compose(lists)
# train_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120)
# assert len(train_data) == 151700

def calculate_md5(fpath, chunk_size=1024 * 1024):
md5 = hashlib.md5()
with open(fpath, "rb") as f:
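
The commented-out loader above is superseded by the imagenet16 branch of getData in xnas/datasets/loader.py; a minimal sketch of the equivalent call (the root path is a placeholder):

from xnas.datasets.loader import getData

train_data, test_data = getData(
    "imagenet16",
    root="./data/imagenet16",   # placeholder path to the ImageNet16 pickles
    cutout_length=0,
    use_classes=120,            # ImageNet16-120; triggers the 151700-sample check
)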


+ 128
- 45
xnas/datasets/loader.py View File

@@ -1,49 +1,132 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch.utils.data as data
import torchvision.datasets as dset

"""Data loader."""
from xnas.datasets.transforms import *
from xnas.datasets.imagenet16 import ImageNet16
from xnas.datasets.imagenet import XNAS_ImageFolder

import os
SUPPORT_DATASETS = ["cifar10", "cifar100", "svhn", "imagenet16"]
# if you use datasets loaded by imagefolder, you can add it here.
IMAGEFOLDER_FORMAT = ["imagenet"]

from xnas.core.config import cfg
from xnas.datasets.cifar10 import XNAS_Cifar10
from xnas.datasets.imagenet import XNAS_ImageFolder
from torch.utils.data.sampler import RandomSampler
from torch.utils.data.distributed import DistributedSampler
# Supported datasets
_DATASETS = {"cifar10": XNAS_Cifar10, "imagenet": XNAS_ImageFolder}
# Relative data paths to default data directory
_PATHS = {"cifar10": "cifar-10-batches-py",
"imagenet": "ImageNet2012"}
def construct_loader(dataset_name, split_list, batch_size, datapath=None):
# Default data directory (/path/pycls/pycls/datasets/data)
if cfg.DATA_LOADER.MEMORY_DATA:
_DATA_DIR = "/userhome/temp_data"
elif datapath is not None:
_DATA_DIR = datapath
def construct_loader(
name,
split,
batch_size,
datapath=None,
cutout_length=0,
num_workers=8,
use_classes=None,
backend="torch",
):
assert (name in SUPPORT_DATASETS) or (
name in IMAGEFOLDER_FORMAT
), "dataset not supported."
datapath = "./data/" + name if datapath is None else datapath
if name in SUPPORT_DATASETS:
train_data, _ = getData(name, datapath, cutout_length, use_classes=use_classes)
return splitDataLoader(train_data, batch_size, split, num_workers)
else:
_DATA_DIR = "/gdata"
# Constructs the data loader for the given dataset
assert dataset_name in _DATASETS and dataset_name in _PATHS, "Dataset '{}' not supported".format(
dataset_name)
# Retrieve the data path for the dataset
data_path = os.path.join(_DATA_DIR, _PATHS[dataset_name])
print("reading data from {}".format(data_path))
# Construct the dataset
loader = _DATASETS[dataset_name](
data_path, split_list, backend=cfg.DATA_LOADER.BACKEND, batch_size=batch_size, num_workers=cfg.DATA_LOADER.NUM_WORKERS)
return loader


# def shuffle(loader, cur_epoch):
# err_str = "Sampler type '{}' not supported".format(type(loader.sampler))
# assert isinstance(loader.sampler, (RandomSampler, DistributedSampler)), err_str
# # RandomSampler handles shuffling automatically
# if isinstance(loader.sampler, DistributedSampler):
# # DistributedSampler shuffles data based on epoch
# loader.sampler.set_epoch(cur_epoch)
data_ = XNAS_ImageFolder(
datapath, split, backend, batch_size=batch_size, num_workers=num_workers
)
return data_.generate_data_loader()


def getData(name, root, cutout_length, download=True, use_classes=None):
assert name in SUPPORT_DATASETS, "dataset not supported."
assert cutout_length >= 0, "cutout_length should not be less than zero."

if name == "cifar10":
train_transform, valid_transform = transforms_cifar10(cutout_length)
train_data = dset.CIFAR10(
root=root, train=True, download=download, transform=train_transform
)
test_data = dset.CIFAR10(
root=root, train=False, download=download, transform=valid_transform
)
elif name == "cifar100":
train_transform, valid_transform = transforms_cifar100(cutout_length)
train_data = dset.CIFAR100(
root=root, train=True, download=download, transform=train_transform
)
test_data = dset.CIFAR100(
root=root, train=True, download=download, transform=valid_transform
)
elif name == "svhn":
train_transform, valid_transform = transforms_svhn(cutout_length)
train_data = dset.SVHN(
root=root, split="train", download=download, transform=train_transform
)
test_data = dset.SVHN(
root=root, split="test", download=download, transform=valid_transform
)
elif name == "imagenet16":
train_transform, valid_transform = transforms_imagenet16()
train_data = ImageNet16(
root=root,
train=True,
transform=train_transform,
use_num_of_class_only=use_classes,
)
test_data = ImageNet16(
root=root,
train=False,
transform=valid_transform,
use_num_of_class_only=use_classes,
)
if use_classes == 120:
assert len(train_data) == 151700
else:
exit(0)
return train_data, test_data


def getDataLoader(
name,
root,
batch_size,
cutout_length,
num_workers=8,
download=True,
use_classes=None,
):
train_data, test_data = getData(name, root, cutout_length, download, use_classes)
train_loader = data.DataLoader(
dataset=train_data,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
)
test_loader = data.DataLoader(
dataset=test_data,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
)
return train_loader, test_loader


def splitDataLoader(data_, batch_size, split, num_workers=8):
assert 0 not in split, "split must not contain zero."
assert sum(split) == 1, "split fractions must sum to one."
num_data = len(data_)
indices = list(range(num_data))
np.random.shuffle(indices)
portion = [int(sum(split[:i]) * num_data) for i in range(len(split) + 1)]

return [
data.DataLoader(
dataset=data_,
batch_size=batch_size,
sampler=data.sampler.SubsetRandomSampler(
indices[portion[i - 1] : portion[i]]
),
num_workers=num_workers,
pin_memory=True,
)
for i in range(1, len(portion))
]
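
splitDataLoader shuffles the indices once and then slices them at cumulative fractions of the dataset size; a small worked example of the portion arithmetic for a 50 000-image training set and the [0.5, 0.5] split used in the configs above:

num_data = 50000
split = [0.5, 0.5]
portion = [int(sum(split[:i]) * num_data) for i in range(len(split) + 1)]
print(portion)   # [0, 25000, 50000]
# loader i is built from indices[portion[i - 1]:portion[i]], i.e. 25 000 samples each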

+ 0
- 123
xnas/datasets/torchDataset.py View File

@@ -1,123 +0,0 @@
import numpy as np
import torch
import torch.utils.data as data
import torchvision.datasets as dset
import torchvision.transforms as transforms


def getTorchDataset(name, root, cutout_length, batch_size, num_workers=8, download=True):
assert name in ['cifar10', 'cifar100', 'svhn'], "dataset not support."
if name == 'cifar10':
train_transform, valid_transform = _transforms_cifar10(cutout_length)
train_data = dset.CIFAR10(root=root, train=True, download=download, transform=train_transform)
test_data = dset.CIFAR10(root=root, train=False, download=download, transform=valid_transform)
elif name == 'cifar100':
train_transform, valid_transform = _transforms_cifar100(cutout_length)
train_data = dset.CIFAR100(root=root, train=True, download=download, transform=train_transform)
test_data = dset.CIFAR100(root=root, train=True, download=download, transform=valid_transform)
elif name == 'svhn':
train_transform, valid_transform = _transforms_svhn(cutout_length)
train_data = dset.SVHN(root=root, split='train', download=download, transform=train_transform)
test_data = dset.SVHN(root=root, split='test', download=download, transform=train_transform)
else:
exit(0)

train_loader = data.DataLoader(
dataset=train_data,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True
)
test_loader = data.DataLoader(
dataset=test_data,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True
)
return train_loader, test_loader


class Cutout(object):
def __init__(self, length):
self.length = length

def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)

y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)

mask[y1:y2, x1:x2] = 0.0
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img


def _transforms_svhn(cutout_length):
SVHN_MEAN = [0.4377, 0.4438, 0.4728]
SVHN_STD = [0.1980, 0.2010, 0.1970]

train_transform = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(SVHN_MEAN, SVHN_STD),
]
)
if cutout_length > 0:
train_transform.transforms.append(Cutout(cutout_length))

valid_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize(SVHN_MEAN, SVHN_STD),]
)
return train_transform, valid_transform


def _transforms_cifar100(cutout_length):
CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
CIFAR_STD = [0.2673, 0.2564, 0.2762]

train_transform = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]
)
if cutout_length:
train_transform.transforms.append(Cutout(cutout_length))

valid_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize(CIFAR_MEAN, CIFAR_STD),]
)
return train_transform, valid_transform


def _transforms_cifar10(cutout_length):
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

train_transform = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]
)
if cutout_length:
train_transform.transforms.append(Cutout(cutout_length))

valid_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize(CIFAR_MEAN, CIFAR_STD),]
)
return train_transform, valid_transform


+ 117
- 2
xnas/datasets/transforms.py View File

@@ -5,11 +5,11 @@

"""Image transformations."""

import math

import cv2
import math
import numpy as np
import torch
import torchvision.transforms as transforms


def color_norm(im, mean, std):
@@ -128,3 +128,118 @@ def torch_lighting(im, alpha_std):
for i in range(im.shape[1]):
im[:, i, :, :] = im[:, i, :, :] + rgb[i]
return im


def transforms_svhn(cutout_length):
SVHN_MEAN = [0.4377, 0.4438, 0.4728]
SVHN_STD = [0.1980, 0.2010, 0.1970]

train_transform = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(SVHN_MEAN, SVHN_STD),
]
)
if cutout_length:
train_transform.transforms.append(Cutout(cutout_length))

valid_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(SVHN_MEAN, SVHN_STD),
]
)
return train_transform, valid_transform


def transforms_cifar100(cutout_length):
CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
CIFAR_STD = [0.2673, 0.2564, 0.2762]

train_transform = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]
)
if cutout_length:
train_transform.transforms.append(Cutout(cutout_length))

valid_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]
)
return train_transform, valid_transform


def transforms_cifar10(cutout_length):
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

train_transform = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]
)
if cutout_length:
train_transform.transforms.append(Cutout(cutout_length))

valid_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]
)
return train_transform, valid_transform


def transforms_imagenet16():
IMAGENET16_MEAN = [0.48109804, 0.45749020, 0.40788235]
IMAGENET16_STD = [0.24792157, 0.24023529, 0.25525490]

train_transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(16, padding=2),
transforms.ToTensor(),
transforms.Normalize(IMAGENET16_MEAN, IMAGENET16_STD),
]
)

# Cutout is not used here.

valid_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize(IMAGENET16_MEAN, IMAGENET16_STD)]
)
return train_transform, valid_transform


class Cutout(object):
def __init__(self, length):
self.length = length

def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)

y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)

mask[y1:y2, x1:x2] = 0.0
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
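
The transform builders append Cutout to the training pipeline only when cutout_length is non-zero, and Cutout itself masks a random square on the already-tensorized image; a minimal usage sketch:

import torch
from xnas.datasets.transforms import Cutout, transforms_cifar10

train_tf, valid_tf = transforms_cifar10(cutout_length=16)
print(type(train_tf.transforms[-1]).__name__)   # Cutout (appended after Normalize)
print(type(valid_tf.transforms[-1]).__name__)   # Normalize (no Cutout at eval time)

img = torch.rand(3, 32, 32)                     # CHW tensor, as produced by ToTensor()
masked = Cutout(16)(img)                        # zeroes a random 16x16 square (clipped at borders)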

+ 3
- 3
xnas/search_space/DrNAS/DARTSspace/cnn.py View File

@@ -478,14 +478,14 @@ class NetworkImageNet(nn.Module):

# build API

def _DrNASCNN_DARTSspace():
def _DrNASCNN_DARTSspace(criterion):
from xnas.core.config import cfg
if cfg.SEARCH.DATASET == 'cifar10':
return NetworkCIFAR(
C=cfg.SPACE.CHANNEL,
num_classes=cfg.SEARCH.NUM_CLASSES,
layers=cfg.SPACE.LAYERS,
criterion=cfg.SEARCH.LOSS_FUN,
criterion=criterion,
k=cfg.DRNAS.K,
reg_type=cfg.DRNAS.REG_TYPE,
reg_scale=cfg.DRNAS.REG_SCALE
@@ -495,7 +495,7 @@ def _DrNASCNN_DARTSspace():
C=cfg.SPACE.CHANNEL,
num_classes=cfg.SEARCH.NUM_CLASSES,
layers=cfg.SPACE.LAYERS,
criterion=cfg.SEARCH.LOSS_FUN,
criterion=criterion,
k=cfg.DRNAS.K
)
else:
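
The previous code handed the constructors cfg.SEARCH.LOSS_FUN, i.e. the plain string "cross_entropy", which breaks the first time the network calls criterion(logits, target); the builder now passes the instantiated loss module instead. A small sketch of the difference, assuming build_loss_fun() resolves "cross_entropy" to nn.CrossEntropyLoss for these configs:

import torch
import torch.nn as nn

logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))

criterion = nn.CrossEntropyLoss()    # what the builder now injects
print(criterion(logits, target))     # a scalar loss tensor

old_value = "cross_entropy"          # what cfg.SEARCH.LOSS_FUN used to provide here
# old_value(logits, target)          # TypeError: 'str' object is not callable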


+ 4
- 4
xnas/search_space/DrNAS/nb201space/cnn.py View File

@@ -570,7 +570,7 @@ class TinyNetworkGDAS(nn.Module):

# build API

def _DrNASCNN_nb201space(species):
def _DrNASCNN_nb201space(species, criterion):
from xnas.core.config import cfg
# if cfg.SEARCH.DATASET == 'cifar10':
return TinyNetwork(
@@ -578,7 +578,7 @@ def _DrNASCNN_nb201space(species):
N=cfg.SPACE.LAYERS,
max_nodes=cfg.SPACE.NODES,
num_classes=cfg.SEARCH.NUM_CLASSES,
criterion=cfg.SEARCH.LOSS_FUN,
criterion=criterion,
search_space=NAS_BENCH_201,
k=cfg.DRNAS.K,
species=species,
@@ -586,7 +586,7 @@ def _DrNASCNN_nb201space(species):
reg_scale=cfg.DRNAS.REG_SCALE
)

def _DrNASCNN_GDAS_nb201space():
def _DrNASCNN_GDAS_nb201space(criterion):
from xnas.core.config import cfg
# if cfg.SEARCH.DATASET == 'cifar10':
return TinyNetworkGDAS(
@@ -594,6 +594,6 @@ def _DrNASCNN_GDAS_nb201space():
N=cfg.SPACE.LAYERS,
max_nodes=cfg.SPACE.NODES,
num_classes=cfg.SEARCH.NUM_CLASSES,
criterion=cfg.SEARCH.LOSS_FUN,
criterion=criterion,
search_space=NAS_BENCH_201
)
