#22 dev

Merged
xfey merged 9 commits from dev into master 1 year ago
  1. +132 -0   configs/search/AttentiveNAS/eval.yaml
  2. +100 -0   configs/search/AttentiveNAS/train.yaml
  3. +89 -0    configs/search/BigNAS/eval.yaml
  4. +133 -0   configs/search/BigNAS/search.yaml
  5. +95 -0    configs/search/BigNAS/train.yaml
  6. +21 -0    configs/search/RMINAS/rminas_proxyless_imagenet.yaml
  7. +1 -1     examples/search/OFA/train_supernet.sh
  8. +271 -0   scripts/search/AttentiveNAS/train_supernet.py
  9. +186 -0   scripts/search/BigNAS/search.py
  10. +254 -0  scripts/search/BigNAS/train_supernet.py
  11. +5 -5    scripts/search/OFA/eval_supernet.py
  12. +8 -8    scripts/search/OFA/train_supernet.py
  13. +35 -25  scripts/search/RMINAS.py
  14. +3 -3    scripts/train/DARTS.py
  15. +1 -1    tests/ofa_matrices_test.py
  16. +117 -0  xnas/algorithms/AttentiveNAS/sampler.py
  17. +47 -1   xnas/algorithms/RMINAS/sampler/RF_sampling.py
  18. +1 -1    xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet.py
  19. +8 -10   xnas/algorithms/RMINAS/utils/random_data.py
  20. +12 -1   xnas/core/builder.py
  21. +6 -1    xnas/core/config.py
  22. +402 -0  xnas/datasets/auto_augment_tf.py
  23. +64 -136 xnas/datasets/imagenet.py
  24. +42 -30  xnas/datasets/loader.py
  25. +124 -0  xnas/datasets/transforms_imagenet.py
  26. +36 -4   xnas/runner/criterion.py
  27. +0 -1    xnas/runner/optimizer.py
  28. +4 -2    xnas/runner/scheduler.py
  29. +7 -7    xnas/runner/trainer.py
  30. +8 -9    xnas/runner/trainer_spos.py
  31. +652 -0  xnas/spaces/AttentiveNAS/cnn.py
  32. +653 -0  xnas/spaces/BigNAS/cnn.py
  33. +331 -0  xnas/spaces/BigNAS/dynamic_layers.py
  34. +181 -0  xnas/spaces/BigNAS/dynamic_ops.py
  35. +133 -0  xnas/spaces/BigNAS/ops.py
  36. +134 -0  xnas/spaces/BigNAS/utils.py
  37. +14 -10  xnas/spaces/OFA/ProxylessNet/cnn.py
  38. +8 -26   xnas/spaces/OFA/dynamic_ops.py
  39. +88 -18  xnas/spaces/OFA/ops.py
  40. +25 -0   xnas/spaces/OFA/utils.py
  41. +271 -0  xnas/spaces/ProxylessNAS/cnn.py

+ 132  - 0  configs/search/AttentiveNAS/eval.yaml

@@ -0,0 +1,132 @@
NUM_GPUS: 4
RNG_SEED: 2
SPACE:
NAME: 'attentivenas'
LOADER:
DATASET: 'imagenet'
NUM_CLASSES: 1000
BATCH_SIZE: 256
NUM_WORKERS: 4
USE_VAL: True
TRANSFORM: "auto_augment_tf"
SEARCH:
IM_SIZE: 224
ATTENTIVENAS:
BN_MOMENTUM: 0.1
BN_EPS: 1.e-5
POST_BN_CALIBRATION_BATCH_NUM: 64
ACTIVE_SUBNET: # chosen from following settings
# attentive_nas_a0
RESOLUTION: 192
WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792]
KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3]
EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
DEPTH: [1, 3, 3, 3, 3, 3, 1]

# # attentive_nas_a1
# RESOLUTION: 224
# WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1984]
# KERNEL_SIZE: [3, 3, 3, 5, 3, 5, 3]
# EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
# DEPTH: [1, 3, 3, 3, 3, 3, 1]

# # attentive_nas_a2
# RESOLUTION: 224
# WIDTH: [16, 16, 24, 32, 64, 112, 200, 224, 1984]
# KERNEL_SIZE: [3, 3, 3, 3, 3, 5, 3]
# EXPAND_RATIO: [1, 4, 5, 4, 4, 6, 6]
# DEPTH: [1, 3, 3, 3, 3, 4, 1]

# # attentive_nas_a3
# RESOLUTION: 224
# WIDTH: [16, 16, 24, 32, 64, 112, 208, 224, 1984]
# KERNEL_SIZE: [3, 3, 3, 5, 3, 3, 3]
# EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
# DEPTH: [2, 3, 3, 4, 3, 5, 1]

# # attentive_nas_a4
# RESOLUTION: 256
# WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1984]
# KERNEL_SIZE: [3, 3, 3, 5, 3, 5, 3]
# EXPAND_RATIO: [1, 4, 4, 5, 4, 6, 6]
# DEPTH: [1, 3, 3, 4, 3, 5, 1]

# # attentive_nas_a5
# RESOLUTION: 256
# WIDTH: [16, 16, 24, 32, 72, 112, 192, 216, 1792]
# KERNEL_SIZE: [3, 3, 3, 5, 3, 3, 3]
# EXPAND_RATIO: [1, 4, 5, 4, 4, 6, 6]
# DEPTH: [1, 3, 3, 3, 4, 6, 1]

# # attentive_nas_a6
# RESOLUTION: 288
# WIDTH: [16, 16, 24, 32, 64, 112, 216, 224, 1984]
# KERNEL_SIZE: [3, 3, 3, 3, 3, 5, 3]
# EXPAND_RATIO: [1, 4, 6, 5, 4, 6, 6]
# DEPTH: [1, 3, 3, 4, 4, 6, 1]
SUPERNET_CFG:
use_v3_head: True
resolutions: [192, 224, 256, 288]
first_conv:
c: [16, 24]
act_func: 'swish'
s: 2
mb1:
c: [16, 24]
d: [1, 2]
k: [3, 5]
t: [1]
s: 1
act_func: 'swish'
se: False
mb2:
c: [24, 32]
d: [3, 4, 5]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: False
mb3:
c: [32, 40]
d: [3, 4, 5, 6]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: True
mb4:
c: [64, 72]
d: [3, 4, 5, 6]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: False
mb5:
c: [112, 120, 128]
d: [3, 4, 5, 6, 7, 8]
k: [3, 5]
t: [4, 5, 6]
s: 1
act_func: 'swish'
se: True
mb6:
c: [192, 200, 208, 216]
d: [3, 4, 5, 6, 7, 8]
k: [3, 5]
t: [6]
s: 2
act_func: 'swish'
se: True
mb7:
c: [216, 224]
d: [1, 2]
k: [3, 5]
t: [6]
s: 1
act_func: 'swish'
se: True
last_conv:
c: [1792, 1984]
act_func: 'swish'

+ 100  - 0  configs/search/AttentiveNAS/train.yaml

@@ -0,0 +1,100 @@
NUM_GPUS: 4
RNG_SEED: 0
SPACE:
NAME: 'attentivenas'
LOADER:
DATASET: 'imagenet'
NUM_CLASSES: 1000
BATCH_SIZE: 64 # 32*8 in total
NUM_WORKERS: 4
USE_VAL: True
TRANSFORM: "auto_augment_tf"
OPTIM:
GRAD_CLIP: 1.
WARMUP_EPOCH: 5
MAX_EPOCH: 360
WEIGHT_DECAY: 1.e-5
BASE_LR: 0.2
NESTEROV: True
SEARCH:
LOSS_FUN: "cross_entropy_smooth"
LABEL_SMOOTH: 0.1
TRAIN:
DROP_PATH_PROB: 0.2
ATTENTIVENAS:
SANDWICH_NUM: 4 # max + 2*middle + min
DROP_CONNECT: 0.2
BN_MOMENTUM: 0.
BN_EPS: 1.e-5
POST_BN_CALIBRATION_BATCH_NUM: 64
SAMPLER:
METHOD: 'bestup'
MAP_PATH: 'xnas/algorithms/AttentiveNAS/flops_archs_off_table.map'
DISCRETIZE_STEP: 25
NUM_TRIALS: 3
SUPERNET_CFG:
use_v3_head: True
resolutions: [192, 224, 256, 288]
first_conv:
c: [16, 24]
act_func: 'swish'
s: 2
mb1:
c: [16, 24]
d: [1, 2]
k: [3, 5]
t: [1]
s: 1
act_func: 'swish'
se: False
mb2:
c: [24, 32]
d: [3, 4, 5]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: False
mb3:
c: [32, 40]
d: [3, 4, 5, 6]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: True
mb4:
c: [64, 72]
d: [3, 4, 5, 6]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: False
mb5:
c: [112, 120, 128]
d: [3, 4, 5, 6, 7, 8]
k: [3, 5]
t: [4, 5, 6]
s: 1
act_func: 'swish'
se: True
mb6:
c: [192, 200, 208, 216]
d: [3, 4, 5, 6, 7, 8]
k: [3, 5]
t: [6]
s: 2
act_func: 'swish'
se: True
mb7:
c: [216, 224]
d: [1, 2]
k: [3, 5]
t: [6]
s: 1
act_func: 'swish'
se: True
last_conv:
c: [1792, 1984]
act_func: 'swish'

+ 89  - 0  configs/search/BigNAS/eval.yaml

@@ -0,0 +1,89 @@
NUM_GPUS: 4
RNG_SEED: 2
SPACE:
NAME: 'bignas'
LOADER:
DATASET: 'imagenet'
NUM_CLASSES: 1000
BATCH_SIZE: 128
NUM_WORKERS: 4
USE_VAL: True
TRANSFORM: "auto_augment_tf"
SEARCH:
IM_SIZE: 224
BIGNAS:
BN_MOMENTUM: 0.1
BN_EPS: 1.e-5
POST_BN_CALIBRATION_BATCH_NUM: 64
ACTIVE_SUBNET: # subnet for evaluation
RESOLUTION: 192
WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792]
KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3]
EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
DEPTH: [1, 3, 3, 3, 3, 3, 1]
SUPERNET_CFG:
use_v3_head: True
resolutions: [192, 224, 256, 288]
first_conv:
c: [16, 24]
act_func: 'swish'
s: 2
mb1:
c: [16, 24]
d: [1, 2]
k: [3, 5]
t: [1]
s: 1
act_func: 'swish'
se: False
mb2:
c: [24, 32]
d: [3, 4, 5]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: False
mb3:
c: [32, 40]
d: [3, 4, 5, 6]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: True
mb4:
c: [64, 72]
d: [3, 4, 5, 6]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: False
mb5:
c: [112, 120, 128]
d: [3, 4, 5, 6, 7, 8]
k: [3, 5]
t: [4, 5, 6]
s: 1
act_func: 'swish'
se: True
mb6:
c: [192, 200, 208, 216]
d: [3, 4, 5, 6, 7, 8]
k: [3, 5]
t: [6]
s: 2
act_func: 'swish'
se: True
mb7:
c: [216, 224]
d: [1, 2]
k: [3, 5]
t: [6]
s: 1
act_func: 'swish'
se: True
last_conv:
c: [1792, 1984]
act_func: 'swish'

+ 133  - 0  configs/search/BigNAS/search.yaml

@@ -0,0 +1,133 @@
NUM_GPUS: 1
RNG_SEED: 2
SPACE:
NAME: 'bignas'
LOADER:
DATASET: 'imagenet'
NUM_CLASSES: 1000
BATCH_SIZE: 128
NUM_WORKERS: 8
USE_VAL: True
TRANSFORM: "auto_augment_tf"
SEARCH:
IM_SIZE: 224
WEIGHTS: "exp/search/test/checkpoints/best_model_epoch_0009.pyth"
BIGNAS:
CONSTRAINT_FLOPS: 6.e+8 # 600M
NUM_MUTATE: 200
BN_MOMENTUM: 0.1
BN_EPS: 1.e-5
POST_BN_CALIBRATION_BATCH_NUM: 64
# ACTIVE_SUBNET: # subnet for evaluation
# RESOLUTION: 192
# WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792]
# KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3]
# EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
# DEPTH: [1, 3, 3, 3, 3, 3, 1]
SEARCH_CFG_SETS:
resolutions: [224, 256]
first_conv:
c: [16]
mb1:
c: [16]
d: [2]
k: [3]
t: [1]
mb2:
c: [24]
d: [3]
k: [3]
t: [5]
mb3:
c: [32]
d: [4]
k: [3]
t: [5]
mb4:
c: [64]
d: [5]
k: [3]
t: [5]
mb5:
c: [120]
d: [6]
k: [3]
t: [5]
mb6:
c: [192]
d: [6]
k: [3, 5]
t: [6]
mb7:
c: [216]
d: [2]
k: [3]
t: [6]
last_conv:
c: [1792]
SUPERNET_CFG:
use_v3_head: True
resolutions: [192, 224, 256, 288]
first_conv:
c: [16, 24]
act_func: 'swish'
s: 2
mb1:
c: [16, 24]
d: [1, 2]
k: [3, 5]
t: [1]
s: 1
act_func: 'swish'
se: False
mb2:
c: [24, 32]
d: [3, 4, 5]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: False
mb3:
c: [32, 40]
d: [3, 4, 5, 6]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: True
mb4:
c: [64, 72]
d: [3, 4, 5, 6]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: False
mb5:
c: [112, 120, 128]
d: [3, 4, 5, 6, 7, 8]
k: [3, 5]
t: [4, 5, 6]
s: 1
act_func: 'swish'
se: True
mb6:
c: [192, 200, 208, 216]
d: [3, 4, 5, 6, 7, 8]
k: [3, 5]
t: [6]
s: 2
act_func: 'swish'
se: True
mb7:
c: [216, 224]
d: [1, 2]
k: [3, 5]
t: [6]
s: 1
act_func: 'swish'
se: True
last_conv:
c: [1792, 1984]
act_func: 'swish'

+ 95  - 0  configs/search/BigNAS/train.yaml

@@ -0,0 +1,95 @@
NUM_GPUS: 4
RNG_SEED: 0
SPACE:
NAME: 'bignas'
LOADER:
DATASET: 'imagenet'
NUM_CLASSES: 1000
BATCH_SIZE: 32
NUM_WORKERS: 4
USE_VAL: True
TRANSFORM: "auto_augment_tf"
OPTIM:
GRAD_CLIP: 1.
WARMUP_EPOCH: 5
MAX_EPOCH: 360
WEIGHT_DECAY: 1.e-5
BASE_LR: 0.1
NESTEROV: True
SEARCH:
LOSS_FUN: "cross_entropy_smooth"
LABEL_SMOOTH: 0.1
TRAIN:
DROP_PATH_PROB: 0.2
BIGNAS:
SANDWICH_NUM: 4 # max + 2*middle + min
DROP_CONNECT: 0.2
BN_MOMENTUM: 0.
BN_EPS: 1.e-5
POST_BN_CALIBRATION_BATCH_NUM: 64
SUPERNET_CFG:
use_v3_head: True
resolutions: [192, 224, 256, 288]
first_conv:
c: [16, 24]
act_func: 'swish'
s: 2
mb1:
c: [16, 24]
d: [1, 2]
k: [3, 5]
t: [1]
s: 1
act_func: 'swish'
se: False
mb2:
c: [24, 32]
d: [3, 4, 5]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: False
mb3:
c: [32, 40]
d: [3, 4, 5, 6]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: True
mb4:
c: [64, 72]
d: [3, 4, 5, 6]
k: [3, 5]
t: [4, 5, 6]
s: 2
act_func: 'swish'
se: False
mb5:
c: [112, 120, 128]
d: [3, 4, 5, 6, 7, 8]
k: [3, 5]
t: [4, 5, 6]
s: 1
act_func: 'swish'
se: True
mb6:
c: [192, 200, 208, 216]
d: [3, 4, 5, 6, 7, 8]
k: [3, 5]
t: [6]
s: 2
act_func: 'swish'
se: True
mb7:
c: [216, 224]
d: [1, 2]
k: [3, 5]
t: [6]
s: 1
act_func: 'swish'
se: True
last_conv:
c: [1792, 1984]
act_func: 'swish'

+ 21  - 0  configs/search/RMINAS/rminas_proxyless_imagenet.yaml

@@ -0,0 +1,21 @@
SPACE:
NAME: 'proxyless'
LOADER:
DATASET: 'imagenet'
NUM_CLASSES: 10
NUM_WORKERS: 0
BATCH_SIZE: 128
OPTIM:
BASE_LR: 0.025
MOMENTUM: 0.9
WEIGHT_DECAY: 0.0003
MAX_EPOCH: 500
TRAIN:
CHANNELS: 16
LAYERS: 8
RMINAS:
LOSS_BETA: 0.8
RF_WARMUP: 100
RF_THRESRATE: 0.05
RF_SUCC: 100
OUT_DIR: 'exp/rminas'

+ 1  - 1  examples/search/OFA/train_supernet.sh

@@ -1,4 +1,4 @@
OUT_NAME="OFA_trail_25"
OUT_NAME="OFA_trial_25"
TASKS="normal_1 kernel_1 depth_1 depth_2 expand_1 expand_2"

for loop in $TASKS


+ 271  - 0  scripts/search/AttentiveNAS/train_supernet.py

@@ -0,0 +1,271 @@
"""AttentiveNAS supernet training"""

import os
import random
import operator

import torch
import torch.nn as nn

import xnas.core.config as config
from xnas.datasets.loader import get_normal_dataloader
import xnas.logger.meter as meter
import xnas.logger.logging as logging
from xnas.core.config import cfg
from xnas.core.builder import *

# DDP
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# AttentiveNAS
from xnas.runner.trainer import Trainer
from xnas.runner.scheduler import adjust_learning_rate_per_batch
from xnas.algorithms.AttentiveNAS.sampler import ArchSampler
from xnas.spaces.OFA.utils import list_mean
from xnas.spaces.BigNAS.utils import init_model


# Load config and check
config.load_configs()
logger = logging.get_logger(__name__)


def main(local_rank, world_size):
setup_env()
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl', rank=local_rank, world_size=world_size)
# Network
net = space_builder().to(local_rank)
init_model(net)
# Loss function
criterion = criterion_builder()
soft_criterion = criterion_builder('kl_soft')
# Data loaders
[train_loader, valid_loader] = get_normal_dataloader()
# Optimizers
net_params = [
# parameters with weight decay
{"params": net.get_parameters(['bn', 'bias'], mode="exclude"), "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
# parameters without weight decay
{"params": net.get_parameters(['bn', 'bias'], mode="include") , "weight_decay": 0},
]
optimizer = optimizer_builder("SGD", net_params)
# sampler for AttentiveNAS
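# MAP_PATH points to a table of pre-profiled architectures (one dict per line with 'flops', 'resolution', 'width', 'depth',
# 'kernel_size', 'expand_ratio'); ArchSampler discretizes their FLOPs by DISCRETIZE_STEP to build FLOPs-conditioned sampling distributions.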
sampler = ArchSampler(
cfg.ATTENTIVENAS.SAMPLER.MAP_PATH, cfg.ATTENTIVENAS.SAMPLER.DISCRETIZE_STEP, net, None
)
net = DDP(net, device_ids=[local_rank], find_unused_parameters=True)
# Initialize Recorder
attnas_trainer = AttentivenasTrainer(
model=net,
criterion=criterion,
soft_criterion=soft_criterion,
sampler=sampler,
optimizer=optimizer,
lr_scheduler=None,
train_loader=train_loader,
test_loader=valid_loader,
)
# Resume
start_epoch = attnas_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0
# Training
logger.info("Start AttentiveNAS training.")
dist.barrier()
attnas_trainer.start()
for cur_epoch in range(start_epoch, cfg.OPTIM.WARMUP_EPOCH+cfg.OPTIM.MAX_EPOCH):
attnas_trainer.train_epoch(cur_epoch, rank=local_rank)
if local_rank == 0:
if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH:
attnas_trainer.validate(cur_epoch, local_rank)
attnas_trainer.finish()
dist.barrier()
torch.cuda.empty_cache()


class AttentivenasTrainer(Trainer):
"""Trainer for AttentiveNAS."""
def __init__(self, model, criterion, soft_criterion, sampler, optimizer, lr_scheduler, train_loader, test_loader):
super().__init__(model, criterion, optimizer, lr_scheduler, train_loader, test_loader)
self.sandwich_sample_num = max(2, cfg.ATTENTIVENAS.SANDWICH_NUM) # containing max & min
self.soft_criterion = soft_criterion
self.sampler = sampler

def train_epoch(self, cur_epoch, rank=0):
self.model.train()
# lr = self.lr_scheduler.get_last_lr()[0]
cur_step = cur_epoch * len(self.train_loader)
# self.writer.add_scalar('train/lr', lr, cur_step)
self.train_meter.iter_tic()
self.train_loader.sampler.set_epoch(cur_epoch) # DDP
for cur_iter, (inputs, labels) in enumerate(self.train_loader):
inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True)
# Adjust lr per iter
cur_lr = adjust_learning_rate_per_batch(
epoch=cur_epoch,
n_iter=len(self.train_loader),
iter=cur_iter,
warmup=(cur_epoch < cfg.OPTIM.WARMUP_EPOCH),
)
for param_group in self.optimizer.param_groups:
param_group["lr"] = cur_lr
# self.writer.add_scalar('train/lr', cur_lr, cur_step)
self.optimizer.zero_grad()
## Sandwich Rule ##
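# Sandwich rule: each step trains the largest subnet on ground-truth labels (with dropout/drop-connect enabled),
# then SANDWICH_NUM-1 smaller subnets (including the smallest) distilled from the largest subnet's logits (inplace distillation).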
# Step 1. Largest network sampling & regularization
self.model.module.sample_max_subnet()
self.model.module.set_dropout_rate(cfg.TRAIN.DROP_PATH_PROB, cfg.ATTENTIVENAS.DROP_CONNECT)
preds = self.model(inputs)
loss = self.criterion(preds, labels)
loss.backward()
with torch.no_grad():
soft_logits = preds.clone().detach()
# Step 2. sample smaller networks
self.model.module.set_dropout_rate(0, 0)
for arch_id in range(1, self.sandwich_sample_num):
if arch_id == self.sandwich_sample_num - 1:
self.model.module.sample_min_subnet()
else:
if self.sampler is not None:
sampling_method = cfg.ATTENTIVENAS.SAMPLER.METHOD
if sampling_method in ['bestup', 'worstup']:
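# bestup/worstup: draw a target FLOPs from the empirical distribution, sample NUM_TRIALS candidates at that FLOPs,
# score each by its loss on the current mini-batch, and keep the best (bestup) or worst (worstup) one.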
target_flops = self.sampler.sample_one_target_flops()
candidate_archs = self.sampler.sample_archs_according_to_flops(
target_flops, n_samples=cfg.ATTENTIVENAS.SAMPLER.NUM_TRIALS
)
my_pred_accs = []
for arch in candidate_archs:
self.model.module.set_active_subnet(**arch)
with torch.no_grad():
my_pred_accs.append(-1.0 * self.criterion(self.model(inputs), labels))
if sampling_method == 'bestup':
idx, _ = max(enumerate(my_pred_accs), key=operator.itemgetter(1))
else:
idx, _ = min(enumerate(my_pred_accs), key=operator.itemgetter(1))
self.model.module.set_active_subnet(**candidate_archs[idx]) #reset
else:
subnet_seed = int("%d%.3d%.3d" % (cur_step, arch_id, 0))
random.seed(subnet_seed)
self.model.module.sample_active_subnet()
preds = self.model(inputs)
if self.soft_criterion is not None:
loss = self.soft_criterion(preds, soft_logits)
else:
loss = self.criterion(preds, labels)
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), cfg.OPTIM.GRAD_CLIP)
self.optimizer.step()
# Calculate errors. Following the original AttentiveNAS code, statistics of the smallest sampled subnet are used; XNAS keeps this behavior.
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
self.train_meter.iter_toc()
self.train_meter.update_stats(top1_err, top5_err, loss, cur_lr, inputs.size(0) * cfg.NUM_GPUS)
self.train_meter.log_iter_stats(cur_epoch, cur_iter)
self.train_meter.iter_tic()
# self.writer.add_scalar('train/loss', i_loss, cur_step)
# self.writer.add_scalar('train/top1_error', i_top1err, cur_step)
# self.writer.add_scalar('train/top5_error', i_top5err, cur_step)
cur_step += 1
# Log epoch stats
self.train_meter.log_epoch_stats(cur_epoch)
self.train_meter.reset()
# self.lr_scheduler.step()
# Saving checkpoint
if rank==0 and (cur_epoch + 1) % cfg.SAVE_PERIOD == 0:
self.saving(cur_epoch)
@torch.no_grad()
def test_epoch(self, subnet, cur_epoch, rank=0):
subnet.eval()
self.test_meter.reset(True)
self.test_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(self.test_loader):
inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True)
preds = subnet(inputs)
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item()
self.test_meter.iter_toc()
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_avg()
top5_err = self.test_meter.mb_top5_err.get_win_avg()
# self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
# self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch)
# self.test_meter.reset()
return top1_err, top5_err


def validate(self, cur_epoch, rank, bn_calibration=True):
subnets_to_be_evaluated = {
'attentive_nas_min_net': {},
'attentive_nas_max_net': {},
}
top1_list, top5_list = [], []
with torch.no_grad():
for net_id in subnets_to_be_evaluated:
if net_id == 'attentive_nas_min_net':
self.model.module.sample_min_subnet()
elif net_id == 'attentive_nas_max_net':
self.model.module.sample_max_subnet()
elif net_id.startswith('attentive_nas_random_net'):
self.model.module.sample_active_subnet()
else:
self.model.module.set_active_subnet(
subnets_to_be_evaluated[net_id]['resolution'],
subnets_to_be_evaluated[net_id]['width'],
subnets_to_be_evaluated[net_id]['depth'],
subnets_to_be_evaluated[net_id]['kernel_size'],
subnets_to_be_evaluated[net_id]['expand_ratio'],
)

subnet = self.model.module.get_active_subnet()
subnet.to(rank)
logger.info("evaluating subnet {}".format(net_id))
if bn_calibration:
subnet.eval()
logger.info("Calibrating BN running statistics.")
subnet.reset_running_stats_for_calibration()
for cur_iter, (inputs, _) in enumerate(self.train_loader):
if cur_iter >= cfg.ATTENTIVENAS.POST_BN_CALIBRATION_BATCH_NUM:
break
inputs = inputs.to(rank)
subnet(inputs) # forward only
top1_err, top5_err = self.test_epoch(subnet, cur_epoch, rank)
top1_list.append(top1_err), top5_list.append(top5_err)
logger.info("Average@all_subnets top1_err:{} top5_err:{}".format(list_mean(top1_list), list_mean(top5_list)))
if self.best_err > list_mean(top1_list):
self.best_err = list_mean(top1_list)
self.saving(cur_epoch, best=True)


if __name__ == '__main__':
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '23333'
if torch.cuda.is_available():
cfg.NUM_GPUS = torch.cuda.device_count()
mp.spawn(main, nprocs=cfg.NUM_GPUS, args=(cfg.NUM_GPUS,), join=True)

+ 186  - 0  scripts/search/BigNAS/search.py

@@ -0,0 +1,186 @@
"""BigNAS subnet searching: Coarse-to-fine Architecture Selection"""

import numpy as np
from itertools import product

import torch

import xnas.core.config as config
import xnas.logger.meter as meter
import xnas.logger.logging as logging
from xnas.core.builder import *
from xnas.core.config import cfg
from xnas.datasets.loader import get_normal_dataloader
from xnas.logger.meter import TestMeter


# Load config and check
config.load_configs()
logger = logging.get_logger(__name__)


def get_all_subnets():
# get all subnets
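# Enumerate the Cartesian product of BIGNAS.SEARCH_CFG_SETS:
# resolution x first_conv width x per-stage (c, d, k, t) choices x last_conv width.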
all_subnets = []
subnet_sets = cfg.BIGNAS.SEARCH_CFG_SETS
stage_names = ['mb1', 'mb2', 'mb3', 'mb4', 'mb5', 'mb6', 'mb7']

mb_stage_subnets = []
for mbstage in stage_names:
mb_block_cfg = getattr(subnet_sets, mbstage)
mb_stage_subnets.append(list(product(
mb_block_cfg.c,
mb_block_cfg.d,
mb_block_cfg.k,
mb_block_cfg.t
)))

all_mb_stage_subnets = list(product(*mb_stage_subnets))

resolutions = getattr(subnet_sets, 'resolutions')
first_conv = getattr(subnet_sets, 'first_conv')
last_conv = getattr(subnet_sets, 'last_conv')

for res in resolutions:
for fc in first_conv.c:
for mb in all_mb_stage_subnets:
np_mb_choice = np.array(mb)
width = np_mb_choice[:, 0].tolist() # c
depth = np_mb_choice[:, 1].tolist() # d
kernel = np_mb_choice[:, 2].tolist() # k
expand = np_mb_choice[:, 3].tolist() # t
for lc in last_conv.c:
all_subnets.append({
'resolution': res,
'width': [fc] + width + [lc],
'depth': depth,
'kernel_size': kernel,
'expand_ratio': expand
})
return all_subnets


def main():
setup_env()
supernet = space_builder().cuda()
supernet.load_weights_from_pretrained_models(cfg.SEARCH.WEIGHTS)

[train_loader, valid_loader] = get_normal_dataloader()

test_meter = TestMeter(len(valid_loader))

all_subnets = get_all_subnets()
benchmarks = []

# Phase 1. coarse search
for k,subnet_cfg in enumerate(all_subnets):
supernet.set_active_subnet(
subnet_cfg['resolution'],
subnet_cfg['width'],
subnet_cfg['depth'],
subnet_cfg['kernel_size'],
subnet_cfg['expand_ratio'],
)
subnet = supernet.get_active_subnet().cuda()
# Validate
top1_err, top5_err = validate(subnet, train_loader, valid_loader, test_meter)
flops = supernet.compute_active_subnet_flops()

logger.info("[{}/{}] flops:{} top1_err:{} top5_err:{}".format(
k+1, len(all_subnets), flops, top1_err, top5_err
))

benchmarks.append({
'subnet_cfg': subnet_cfg,
'flops': flops,
'top1_err': top1_err,
'top5_err': top5_err
})

# Phase 2. fine-grained search
try:
best_subnet_info = list(filter(
lambda k: k['flops'] < cfg.BIGNAS.CONSTRAINT_FLOPS,
sorted(benchmarks, key=lambda d: d['top1_err'])))[0]
best_subnet_cfg = best_subnet_info['subnet_cfg']
best_subnet_top1 = best_subnet_info['top1_err']
except IndexError:
logger.info("Cannot find subnets under {} FLOPs".format(cfg.BIGNAS.CONSTRAINT_FLOPS))
exit(1)
for mutate_epoch in range(cfg.BIGNAS.NUM_MUTATE):
new_subnet_cfg = supernet.mutate_and_reset(best_subnet_cfg)
prev_cfgs = [i['subnet_cfg'] for i in benchmarks]
if new_subnet_cfg in prev_cfgs:
continue
subnet = supernet.get_active_subnet().cuda()
# Validate
top1_err, top5_err = validate(subnet, train_loader, valid_loader, test_meter)
flops = supernet.compute_active_subnet_flops()
logger.info("[{}/{}] flops:{} top1_err:{} top5_err:{}".format(
mutate_epoch+1, cfg.BIGNAS.NUM_MUTATE, flops, top1_err, top5_err
))

benchmarks.append({
'subnet_cfg': new_subnet_cfg,
'flops': flops,
'top1_err': top1_err,
'top5_err': top5_err
})
if flops < cfg.BIGNAS.CONSTRAINT_FLOPS and top1_err < best_subnet_top1:
best_subnet_cfg = new_subnet_cfg
best_subnet_top1 = top1_err
# Final best architecture
logger.info("="*20 + "\nMutate Finished.")
logger.info("Best Architecture:\n{}\n Best top1_err:{}".format(
best_subnet_cfg, best_subnet_top1
))


@torch.no_grad()
def validate(subnet, train_loader, valid_loader, test_meter):
# BN calibration
subnet.eval()
logger.info("Calibrating BN running statistics.")
subnet.reset_running_stats_for_calibration()
for cur_iter, (inputs, _) in enumerate(train_loader):
if cur_iter >= cfg.BIGNAS.POST_BN_CALIBRATION_BATCH_NUM:
break
inputs = inputs.cuda()
subnet(inputs) # forward only

top1_err, top5_err = test_epoch(subnet, valid_loader, test_meter)
return top1_err, top5_err


def test_epoch(subnet, test_loader, test_meter):
subnet.eval()
test_meter.reset(True)
test_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(test_loader):
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
preds = subnet(inputs)
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item()

test_meter.iter_toc()
test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
test_meter.log_iter_stats(0, cur_iter)
test_meter.iter_tic()
top1_err = test_meter.mb_top1_err.get_win_avg()
top5_err = test_meter.mb_top5_err.get_win_avg()
# self.writer.add_scalar('val/top1_error', test_meter.mb_top1_err.get_win_avg(), cur_epoch)
# self.writer.add_scalar('val/top5_error', test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats
test_meter.log_epoch_stats(0)
# test_meter.reset()
return top1_err, top5_err


if __name__ == "__main__":
main()

+ 254  - 0  scripts/search/BigNAS/train_supernet.py

@@ -0,0 +1,254 @@
"""AttentiveNAS supernet training"""

import os
import random

import torch
import torch.nn as nn

import xnas.core.config as config
from xnas.datasets.loader import get_normal_dataloader
import xnas.logger.meter as meter
import xnas.logger.logging as logging
from xnas.core.config import cfg
from xnas.core.builder import *

# DDP
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# BigNAS
from xnas.runner.trainer import Trainer
from xnas.runner.scheduler import adjust_learning_rate_per_batch
from xnas.spaces.OFA.utils import list_mean
from xnas.spaces.BigNAS.utils import init_model

# Load config and check
config.load_configs()
logger = logging.get_logger(__name__)


def main(local_rank, world_size):
setup_env()
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl', rank=local_rank, world_size=world_size)
# Network
net = space_builder().to(local_rank)
init_model(net)
# Loss function
criterion = criterion_builder()
soft_criterion = criterion_builder('kl_soft')
# Data loaders
[train_loader, valid_loader] = get_normal_dataloader()
# Optimizers
net_params = [
# parameters with weight decay
{"params": net.get_parameters(['bn', 'bias'], mode="exclude"), "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
# parameters without weight decay
{"params": net.get_parameters(['bn', 'bias'], mode="include") , "weight_decay": 0},
]
optimizer = optimizer_builder("SGD", net_params)
# Rule: only regularize the biggest model
optimizer_no_wd = torch.optim.SGD(
net.parameters(),
cfg.OPTIM.BASE_LR,
cfg.OPTIM.MOMENTUM,
cfg.OPTIM.DAMPENING,
0, # no weight decay.
cfg.OPTIM.NESTEROV,
)
net = DDP(net, device_ids=[local_rank], find_unused_parameters=True)
# Initialize Recorder
bignas_trainer = BigNASTrainer(
model=net,
criterion=criterion,
soft_criterion=soft_criterion,
optimizer=optimizer,
optim_no_wd=optimizer_no_wd,
lr_scheduler=None,
train_loader=train_loader,
test_loader=valid_loader,
)
# Resume
start_epoch = bignas_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0
# Training
logger.info("Start BigNAS training.")
dist.barrier()
bignas_trainer.start()
for cur_epoch in range(start_epoch, cfg.OPTIM.WARMUP_EPOCH+cfg.OPTIM.MAX_EPOCH):
bignas_trainer.train_epoch(cur_epoch, rank=local_rank)
if local_rank == 0:
if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH:
bignas_trainer.validate(cur_epoch, local_rank)
bignas_trainer.finish()
dist.barrier()
torch.cuda.empty_cache()


class BigNASTrainer(Trainer):
"""Trainer for BigNAS."""
def __init__(self, model, criterion, soft_criterion, optimizer, optim_no_wd, lr_scheduler, train_loader, test_loader):
super().__init__(model, criterion, optimizer, lr_scheduler, train_loader, test_loader)
self.sandwich_sample_num = max(2, cfg.BIGNAS.SANDWICH_NUM) # containing max & min
self.soft_criterion = soft_criterion
self.optim_no_wd = optim_no_wd

def train_epoch(self, cur_epoch, rank=0):
self.model.train()
# lr = self.lr_scheduler.get_last_lr()[0]
cur_step = cur_epoch * len(self.train_loader)
# self.writer.add_scalar('train/lr', lr, cur_step)
self.train_meter.iter_tic()
self.train_loader.sampler.set_epoch(cur_epoch) # DDP
for cur_iter, (inputs, labels) in enumerate(self.train_loader):
inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True)
# Adjust lr per iter
cur_lr = adjust_learning_rate_per_batch(
epoch=cur_epoch,
n_iter=len(self.train_loader),
iter=cur_iter,
warmup=(cur_epoch < cfg.OPTIM.WARMUP_EPOCH),
)
# Rule: constant ending (do not let the learning rate decay below 5% of the base LR)
cur_lr = max(cur_lr, 0.05 * cfg.OPTIM.BASE_LR)
for param_group in self.optimizer.param_groups:
param_group["lr"] = cur_lr
# self.writer.add_scalar('train/lr', cur_lr, cur_step)
## Sandwich Rule ##
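# Sandwich rule (as in AttentiveNAS): the largest subnet is trained on ground-truth labels, smaller subnets on its soft logits.
# Only the largest subnet uses the weight-decay optimizer; the smaller subnets are updated with optim_no_wd ("only regularize the biggest model").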
# Step 1. Largest network sampling & regularization
self.optimizer.zero_grad()
self.model.module.sample_max_subnet()
self.model.module.set_dropout_rate(cfg.TRAIN.DROP_PATH_PROB, cfg.BIGNAS.DROP_CONNECT)
preds = self.model(inputs)
loss = self.criterion(preds, labels)
loss.backward()
self.optimizer.step()
with torch.no_grad():
soft_logits = preds.clone().detach()
# Step 2. sample smaller networks
self.optim_no_wd.zero_grad()
self.model.module.set_dropout_rate(0, 0)
for arch_id in range(1, self.sandwich_sample_num):
if arch_id == self.sandwich_sample_num - 1:
self.model.module.sample_min_subnet()
else:
subnet_seed = int("%d%.3d%.3d" % (cur_step, arch_id, 0))
random.seed(subnet_seed)
self.model.module.sample_active_subnet()
preds = self.model(inputs)
if self.soft_criterion is not None:
loss = self.soft_criterion(preds, soft_logits)
else:
loss = self.criterion(preds, labels)
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), cfg.OPTIM.GRAD_CLIP)
self.optim_no_wd.step()
# Calculate errors. Following the original AttentiveNAS code, statistics of the smallest sampled subnet are used; XNAS keeps this behavior.
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
self.train_meter.iter_toc()
self.train_meter.update_stats(top1_err, top5_err, loss, cur_lr, inputs.size(0) * cfg.NUM_GPUS)
self.train_meter.log_iter_stats(cur_epoch, cur_iter)
self.train_meter.iter_tic()
# self.writer.add_scalar('train/loss', i_loss, cur_step)
# self.writer.add_scalar('train/top1_error', i_top1err, cur_step)
# self.writer.add_scalar('train/top5_error', i_top5err, cur_step)
cur_step += 1
# Log epoch stats
self.train_meter.log_epoch_stats(cur_epoch)
self.train_meter.reset()
# self.lr_scheduler.step()
# Saving checkpoint
if rank==0 and (cur_epoch + 1) % cfg.SAVE_PERIOD == 0:
self.saving(cur_epoch)
@torch.no_grad()
def test_epoch(self, subnet, cur_epoch, rank=0):
subnet.eval()
self.test_meter.reset(True)
self.test_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(self.test_loader):
inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True)
preds = subnet(inputs)
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item()
self.test_meter.iter_toc()
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_avg()
top5_err = self.test_meter.mb_top5_err.get_win_avg()
# self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
# self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch)
# self.test_meter.reset()
return top1_err, top5_err


def validate(self, cur_epoch, rank, bn_calibration=True):
subnets_to_be_evaluated = {
'bignas_min_net': {},
'bignas_max_net': {},
}
top1_list, top5_list = [], []
with torch.no_grad():
for net_id in subnets_to_be_evaluated:
if net_id == 'bignas_min_net':
self.model.module.sample_min_subnet()
elif net_id == 'bignas_max_net':
self.model.module.sample_max_subnet()
elif net_id.startswith('bignas_random_net'):
self.model.module.sample_active_subnet()
else:
self.model.module.set_active_subnet(
subnets_to_be_evaluated[net_id]['resolution'],
subnets_to_be_evaluated[net_id]['width'],
subnets_to_be_evaluated[net_id]['depth'],
subnets_to_be_evaluated[net_id]['kernel_size'],
subnets_to_be_evaluated[net_id]['expand_ratio'],
)

subnet = self.model.module.get_active_subnet()
subnet.to(rank)
logger.info("evaluating subnet {}".format(net_id))
if bn_calibration:
subnet.eval()
logger.info("Calibrating BN running statistics.")
subnet.reset_running_stats_for_calibration()
for cur_iter, (inputs, _) in enumerate(self.train_loader):
if cur_iter >= cfg.BIGNAS.POST_BN_CALIBRATION_BATCH_NUM:
break
inputs = inputs.to(rank)
subnet(inputs) # forward only
top1_err, top5_err = self.test_epoch(subnet, cur_epoch, rank)
top1_list.append(top1_err), top5_list.append(top5_err)
logger.info("Average@all_subnets top1_err:{} top5_err:{}".format(list_mean(top1_list), list_mean(top5_list)))
if self.best_err > list_mean(top1_list):
self.best_err = list_mean(top1_list)
self.saving(cur_epoch, best=True)


if __name__ == '__main__':
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '23333'
if torch.cuda.is_available():
cfg.NUM_GPUS = torch.cuda.device_count()
mp.spawn(main, nprocs=cfg.NUM_GPUS, args=(cfg.NUM_GPUS,), join=True)

+ 5  - 5  scripts/search/OFA/eval_supernet.py

@@ -96,7 +96,7 @@ def main():
# load_last_stage_ckpt(cfg.OFA.TASK, cfg.OFA.PHASE)
# ofa_trainer.resume() # only load the state_dict of model

# cfg.SEARCH.WEIGHTS = '/home/xfey/XNAS/exp/search/OFA_trail_25/kernel_1/checkpoints/model_epoch_0110.pyth'
# cfg.SEARCH.WEIGHTS = '/home/xfey/XNAS/exp/search/OFA_trial_25/kernel_1/checkpoints/model_epoch_0110.pyth'
cfg.SEARCH.WEIGHTS = '/home/xfey/XNAS/tests/weights/ofa_D4_E6_K357'
ofa_trainer.resume()

@@ -137,10 +137,10 @@ class OFATrainer(KDTrainer):
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_median()
top5_err = self.test_meter.mb_top5_err.get_win_median()
# self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
# self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
top1_err = self.test_meter.mb_top1_err.get_win_avg()
top5_err = self.test_meter.mb_top5_err.get_win_avg()
# self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
# self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch)
# self.test_meter.reset()


+ 8  - 8  scripts/search/OFA/train_supernet.py

@@ -9,6 +9,7 @@ import torch.nn as nn
import torch.nn.functional as F

import xnas.core.config as config
from xnas.datasets.loader import get_normal_dataloader
import xnas.logger.meter as meter
import xnas.logger.logging as logging
from xnas.core.config import cfg
@@ -44,7 +45,7 @@ def main(local_rank, world_size):
# Loss function
criterion = criterion_builder()
# Data loaders
[train_loader, valid_loader] = construct_loader()
[train_loader, valid_loader] = get_normal_dataloader()
# Optimizers
net_params = [
# parameters with weight decay
@@ -241,10 +242,10 @@ class OFATrainer(KDTrainer):
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_median()
top5_err = self.test_meter.mb_top5_err.get_win_median()
# self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
# self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
top1_err = self.test_meter.mb_top1_err.get_win_avg()
top5_err = self.test_meter.mb_top5_err.get_win_avg()
# self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
# self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch)
# self.test_meter.reset()
@@ -320,8 +321,8 @@ class OFATrainer(KDTrainer):
logger.info("Average@all_subnets top1_err:{} top5_err:{}".format(list_mean(top1errs), list_mean(top5errs)))
# Saving best model
if self.best_err > top1_err:
self.best_err = top1_err
if self.best_err > list_mean(top1errs):
self.best_err = list_mean(top1errs)
self.saving(cur_epoch, best=True)


@@ -331,6 +332,5 @@ if __name__ == '__main__':
if torch.cuda.is_available():
cfg.NUM_GPUS = torch.cuda.device_count()
print(cfg.NUM_GPUS)
mp.spawn(main, nprocs=cfg.NUM_GPUS, args=(cfg.NUM_GPUS,), join=True)

+ 35  - 25  scripts/search/RMINAS.py

@@ -4,7 +4,7 @@ import time
import numpy as np

import torch
from fvcore.nn import FlopCountAnalysis
import xnas.core.config as config
import xnas.logger.logging as logging
from xnas.core.config import cfg
@@ -38,6 +38,8 @@ def rminas_hp_builder():
RF_space = 'nasbenchmacro'
from xnas.evaluations.NASBenchMacro.evaluate import evaluate, data
api = data
elif cfg.SPACE.NAME == 'proxyless':
RF_space = 'proxyless'
# for example : arch = '00000000'
# arch = ''
# evaluate(arch)
@@ -48,7 +50,7 @@ def main():
rminas_hp_builder()
assert cfg.SPACE.NAME in ['infer_nb201', 'infer_darts',"nasbenchmacro"]
assert cfg.SPACE.NAME in ['infer_nb201', 'infer_darts',"nasbenchmacro", "proxyless"]
assert cfg.LOADER.DATASET in ['cifar10', 'cifar100', 'imagenet', 'imagenet16_120'], 'dataset error'

if cfg.LOADER.DATASET == 'cifar10':
@@ -64,7 +66,7 @@ def main():
network.load_state_dict(torch.load('xnas/algorithms/RMINAS/teacher_model/resnet101_cifar100/resnet101.pth'))

elif cfg.LOADER.DATASET == 'imagenet':
assert cfg.SPACE.NAME == 'infer_darts'
assert cfg.SPACE.NAME in ('infer_darts', 'proxyless')
logger.warning('Our method does not directly search in ImageNet.')
logger.warning('Only partial tests have been conducted, please use with caution.')
import xnas.algorithms.RMINAS.teacher_model.fbresnet_imagenet.fbresnet as fbresnet
@@ -93,8 +95,8 @@ def main():
ce_loss = torch.nn.CrossEntropyLoss(reduction='none').cuda()
more_logits = network(more_data_X)
_, indices = torch.topk(-ce_loss(more_logits, more_data_y).cpu().detach(), cfg.LOADER.BATCH_SIZE)
data_y = torch.Tensor([more_data_y[i] for i in indices]).long().cuda()
data_X = torch.Tensor([more_data_X[i].cpu().numpy() for i in indices]).cuda()
data_y = more_data_y.detach()
data_X = more_data_X.detach()
with torch.no_grad():
feature_res = network.feature_extractor(data_X)
@@ -107,6 +109,7 @@ def main():
loss_fun_log = torch.nn.CrossEntropyLoss().cuda()
def train_arch(modelinfo):
flops = None
if cfg.SPACE.NAME == 'infer_nb201':
# get arch
arch_config = {
@@ -122,6 +125,12 @@ def main():
elif cfg.SPACE.NAME == 'nasbenchmacro':
model = space_builder().cuda()
optimizer = optimizer_builder("SGD", model.parameters())
elif cfg.SPACE.NAME == 'proxyless':
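# modelinfo is the flat vector produced by the proxyless RF sampler:
# [:6] per-stage depths, [6:27] kernel sizes, [27:] expand ratios; zero entries mark unused block slots and are filtered out here.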
model = space_builder(stage_width_list=[16, 24, 40, 80, 96, 192, 320],depth_param=modelinfo[:6],ks=modelinfo[6:27][modelinfo[6:27]>0],expand_ratio=modelinfo[27:][modelinfo[27:]>0],dropout_rate=0).cuda()
optimizer = optimizer_builder("SGD", model.parameters())
with torch.no_grad():
tensor = (torch.rand(1, 3, 224, 224).cuda(),)
flops = FlopCountAnalysis(model, tensor).total()
# lr_scheduler = lr_scheduler_builder(optimizer)

# nbm_trainer = OneShotTrainer(
@@ -150,12 +159,14 @@ def main():
optimizer.step()
epoch_losses.append(loss.detach().cpu().item())
if cur_epoch == cfg.OPTIM.MAX_EPOCH:
return loss.cpu().detach().numpy(), epoch_losses

return loss.cpu().detach().numpy(), {'epoch_losses':epoch_losses, 'flops':flops}

trained_arch_darts, trained_loss = [], []
def train_procedure(sample):
if cfg.SPACE.NAME == 'infer_nb201':
mixed_loss = train_arch(sample)[0]
mixed_loss, epoch_losses = train_arch(sample)
mixed_loss = np.inf if np.isnan(mixed_loss) else mixed_loss
trained_loss.append(mixed_loss)
arch_arr = sampling.nb201genostr2array(api.arch(sample))
@@ -164,17 +175,25 @@ def main():
elif cfg.SPACE.NAME == 'infer_darts':
sample_geno = geno_from_alpha(sampling.darts_sug2alpha(sample)) # type=Genotype
trained_arch_darts.append(str(sample_geno))
mixed_loss = train_arch(sample_geno)[0]
mixed_loss, epoch_losses = train_arch(sample_geno)
mixed_loss = np.inf if np.isnan(mixed_loss) else mixed_loss
trained_loss.append(mixed_loss)
RFS.trained_arch.append({'arch':sample, 'loss':mixed_loss})
elif cfg.SPACE.NAME == 'nasbenchmacro':
sample_geno = ''.join(sample.astype('str')) # type=Genotype
trained_arch_darts.append((sample_geno))
mixed_loss, epoch_losses = train_arch(sample)
mixed_loss, info = train_arch(sample)
mixed_loss = np.inf if np.isnan(mixed_loss) else mixed_loss
trained_loss.append(mixed_loss)
RFS.trained_arch.append({'arch':sample, 'loss':mixed_loss,'gt':api[sample_geno]['mean_acc'],'losses':info["epoch_losses"]})
elif cfg.SPACE.NAME == 'proxyless':
sample_geno = ''.join(sample.astype('str')) # type=Genotype
trained_arch_darts.append((sample_geno))
mixed_loss, info = train_arch(sample)
mixed_loss = np.inf if np.isnan(mixed_loss) else mixed_loss
trained_loss.append(mixed_loss)
RFS.trained_arch.append({'arch':sample, 'loss':mixed_loss,'gt':api[sample_geno]['mean_acc'],'losses':epoch_losses})
RFS.trained_arch.append({'arch':sample, 'loss':mixed_loss,'gt':info["flops"],'losses':info["epoch_losses"]})

logger.info("sample: {}, loss:{}".format(sample, mixed_loss))
@@ -185,21 +204,6 @@ def main():
for sample in warmup_samples:
train_procedure(sample)
RFS.Warmup()
logger.info('warmup time cost: {}'.format(str(time.time() - start_time)))
# with open('./rmi_nbm.pkl','wb') as f:
# pickle.dump(RFS.trained_arch,f)
# gt = np.array([_['gt'] for _ in RFS.trained_arch])
# losses = np.array([_['losses'] for _ in RFS.trained_arch])
# from scipy.stats import kendalltau
# kdts = []
# for epoch in range(losses.shape[-1]):
# kdts.append(kendalltau(gt, -losses[:, epoch]).correlation)
# import matplotlib.pyplot as plt
# plt.plot(kdts)
# plt.xlabel('epoch')
# plt.ylabel('kdt')
# plt.savefig('rmi_nbm.png')
# sys.exit()
# ====== RF Sampling ======
sampling_time = time.time()
sampling_cnt= 0
@@ -231,6 +235,12 @@ def main():
# op_geno = reformat_DARTS(geno_from_alpha(op_alpha))
logger.info('Searched architecture@top50:\n{}'.format(str(op_sample)))
print(api[op_sample]['mean_acc'])
elif cfg.SPACE.NAME == 'proxyless':
op_sample = RFS.optimal_arch(method='sum', top=100)
# op_alpha = torch.from_numpy(np.r_[op_sample, op_sample])
# op_geno = reformat_DARTS(geno_from_alpha(op_alpha))
logger.info('Searched architecture@top100:\n{}'.format(str(op_sample)))
print(api[op_sample]['mean_acc'])

if __name__ == '__main__':
main()

+ 3  - 3  scripts/train/DARTS.py

@@ -97,9 +97,9 @@ class Darts_Retrainer(Trainer):
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_median()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
top1_err = self.test_meter.mb_top1_err.get_win_avg()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch)
self.test_meter.reset()


+ 1  - 1  tests/ofa_matrices_test.py

@@ -2,7 +2,7 @@ import os
import torch

def test_local():
root = '/home/xfey/XNAS/exp/search/OFA_trail_25/kernel_1/checkpoints/'
root = '/home/xfey/XNAS/exp/search/OFA_trial_25/kernel_1/checkpoints/'
filename_prefix = 'model_epoch_'
filename_postfix = '.pyth'



+ 117  - 0  xnas/algorithms/AttentiveNAS/sampler.py

@@ -0,0 +1,117 @@
import random

def count_helper(v, flops, m):
if flops not in m:
m[flops] = {}
if v not in m[flops]:
m[flops][v] = 0
m[flops][v] += 1


def round_flops(flops, step):
return int(round(flops / step) * step)


def convert_count_to_prob(m):
if isinstance(m[list(m.keys())[0]], dict):
for k in m:
convert_count_to_prob(m[k])
else:
t = sum(m.values())
for k in m:
m[k] = 1.0 * m[k] / t


def sample_helper(flops, m):
keys = list(m[flops].keys())
probs = list(m[flops].values())
return random.choices(keys, weights=probs)[0]


def build_trasition_prob_matrix(file_handler, step):
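# Each line of the map file evaluates to a dict with 'flops', 'resolution', 'width', 'depth', 'kernel_size' and 'expand_ratio'.
# FLOPs are discretized with the given step; per-bucket counts of the remaining fields are then converted to sampling probabilities.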
# initialize
prob_map = {}
prob_map['discretize_step'] = step
for k in ['flops', 'resolution', 'width', 'depth', 'kernel_size', 'expand_ratio']:
prob_map[k] = {}

cc = 0
for line in file_handler:
vals = eval(line.strip())

# discretize
flops = round_flops(vals['flops'], step)
prob_map['flops'][flops] = prob_map['flops'].get(flops, 0) + 1

# resolution
r = vals['resolution']
count_helper(r, flops, prob_map['resolution'])

for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
for idx, v in enumerate(vals[k]):
if idx not in prob_map[k]:
prob_map[k][idx] = {}
count_helper(v, flops, prob_map[k][idx])

cc += 1

# convert count to probability
for k in ['flops', 'resolution', 'width', 'depth', 'kernel_size', 'expand_ratio']:
convert_count_to_prob(prob_map[k])
prob_map['n_observations'] = cc
return prob_map



class ArchSampler():
def __init__(self, arch_to_flops_map_file_path, discretize_step, model, acc_predictor=None):
super(ArchSampler, self).__init__()

with open(arch_to_flops_map_file_path, 'r') as fp:
self.prob_map = build_trasition_prob_matrix(fp, discretize_step)

self.discretize_step = discretize_step
self.model = model

self.acc_predictor = acc_predictor

self.min_flops = min(list(self.prob_map['flops'].keys()))
self.max_flops = max(list(self.prob_map['flops'].keys()))

self.curr_sample_pool = None  # TODO: architecture samples could be generated in an asynchronous way


def sample_one_target_flops(self, flops_uniform=False):
f_vals = list(self.prob_map['flops'].keys())
f_probs = list(self.prob_map['flops'].values())

if flops_uniform:
return random.choice(f_vals)
else:
return random.choices(f_vals, weights=f_probs)[0]


def sample_archs_according_to_flops(self, target_flops, n_samples=1, max_trials=100, return_flops=True, return_trials=False):
archs = []
#for _ in range(n_samples):
while len(archs) < n_samples:
for _trial in range(max_trials+1):
arch = {}
arch['resolution'] = sample_helper(target_flops, self.prob_map['resolution'])
for k in ['width', 'kernel_size', 'depth', 'expand_ratio']:
arch[k] = []
for idx in sorted(list(self.prob_map[k].keys())):
arch[k].append(sample_helper(target_flops, self.prob_map[k][idx]))
if self.model:
self.model.set_active_subnet(**arch)
flops = self.model.compute_active_subnet_flops()
if return_flops:
arch['flops'] = flops
if round_flops(flops, self.discretize_step) == target_flops:
break
else:
raise NotImplementedError
# accept the sample anyway
archs.append(arch)
return archs


+ 47  - 1  xnas/algorithms/RMINAS/sampler/RF_sampling.py

@@ -35,7 +35,8 @@ class RF_suggest():
self.max_space = int(3**8)
self.num_estimator = 30
self.spaces = list(api.keys())
elif self.space == 'proxyless':
self.num_estimator = 100
self.model = RandomForestClassifier(n_estimators=self.num_estimator,random_state=seed)
def _update_lossthres(self):
@@ -74,6 +75,8 @@ class RF_suggest():
return [self._single_sample() for _ in range(num_warmup)]
elif self.space == 'nasbenchmacro':
return [self._single_sample() for _ in range(num_warmup)]
elif self.space == 'proxyless':
return [self._single_sample() for _ in range(num_warmup)]
def _single_sample(self, unique=True):
if self.space == 'nasbench201':
@@ -125,6 +128,28 @@ class RF_suggest():
else:
numeric_choice = np.random.randint(3,size=8)
return numeric_choice
elif self.space == 'proxyless':
def gen_sample():
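# Encode an architecture as a 48-dim vector: 6 stage depths (5 stages of depth 1-4 plus a fixed final stage of depth 1),
# then 21 kernel sizes and 21 expand ratios (4 block slots per searchable stage, 1 for the last); slots beyond a stage's depth are zeroed.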
depth = np.array(np.random.randint(1, 4+1, size=5).tolist() + [1])
anchors = depth+[0,4,8,12,16,20]
ks = np.random.choice([3,5,7], size=21)
expand_ratios = np.random.choice([3,6], size=21)
ed = 4
for anchor in anchors:
ks[anchor:ed] = 0
expand_ratios[anchor:ed] = 0
ed += 4
sample = np.concatenate([depth, ks, expand_ratios])
return sample
if unique:
while True:
sample = gen_sample()
if sample.tobytes() not in self.sampled_history:
self.sampled_history.append(sample.tobytes())
return sample
else:
sample = gen_sample()
return sample

def Warmup(self):
@@ -177,6 +202,18 @@ class RF_suggest():
for i in _sample_indexes:
if self.spaces[i] not in chace_table:
_sample_archs.append(np.array(list(self.spaces[i])).astype(int))
elif self.space == 'proxyless':
_sample_batch = np.array([self._single_sample(unique=False).ravel() for _ in range(self.batch)])
_tmp_trained_arch = [(i['arch'].tobytes()) for i in self.trained_arch]
_sample_archs = []
for i in _sample_batch:
if (i).tobytes() not in _tmp_trained_arch:
_sample_archs.append(i)
# print("sample {} archs/batch, cost time: {}".format(len(_sample_archs), time.time()-start_time))
best_id = np.argmax(self.model.predict_proba(_sample_archs)[:,1])
best_arch = _sample_archs[best_id]
return best_arch
# _sample_batch = np.array([self._single_sample(unique=True).ravel() for _ in range(self.batch)])
# _tmp_trained_arch = [str(i['arch'].ravel()) for i in self.trained_arch]
# _sample_archs = []
@@ -311,3 +348,12 @@ class RF_suggest():
op_arr = np.zeros((_tmp_np.size, 3))
op_arr[np.arange(_tmp_np.size),_tmp_np] = 1
return op_arr.argmax(-1)
elif self.space == 'proxyless':
assert method == 'sum', 'only sum is supported in mb.'
depth = estimate_archs[:, :6]
best_depth = np.eye(4)[depth].argmax(-1)+1
ks = estimate_archs[:, 6:27]//2 # {3, 5, 7}
best_ks = np.eye(3)[ks].argmax(-1) * 2 + 3
er = estimate_archs[:, 27:]//3 # {3, 6}
best_er = np.eye(2)[er].argmax(-1) * 3 + 3
return np.concatenate([best_depth, best_ks, best_er])

+ 1  - 1  xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet.py

@@ -7,7 +7,7 @@ import math
import torch.utils.model_zoo as model_zoo
import torch

WEIGHT_PATH = 'teacher_model/fbresnet_imagenet/fbresnet152.pth'
WEIGHT_PATH = 'xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet152.pth'

__all__ = ['FBResNet',
#'fbresnet18', 'fbresnet34', 'fbresnet50', 'fbresnet101',


+ 8  - 10  xnas/algorithms/RMINAS/utils/random_data.py

@@ -1,5 +1,6 @@
import torch
import random
import numpy as np
from xnas.datasets.loader import get_normal_dataloader
from xnas.datasets.imagenet import ImageFolder

@@ -8,18 +9,15 @@ def get_random_data(batchsize, name):
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if name == 'imagenet':
train_loader, _ = ImageFolder(
"./data/imagenet/ILSVRC2012_img_train/",
[0.5, 0.5],
batchsize*16,
datapath="./data/imagenet/ILSVRC2012_img_train/",
batch_size=batchsize*16,
split=[0.5, 0.5],
).generate_data_loader()
else:
train_loader, _ = get_normal_dataloader(name, batchsize*16)
target_i = random.randint(0, len(train_loader)-1)
more_data_X, more_data_y = None, None
for i, (more_data_X, more_data_y) in enumerate(train_loader):
if i == target_i:
break
more_data_X = more_data_X.to(device)
more_data_y = more_data_y.to(device)
random_idxs = np.random.randint(0, len(train_loader.dataset), size=train_loader.batch_size)
(more_data_X, more_data_y) = zip(*[train_loader.dataset[idx] for idx in random_idxs])
more_data_X = torch.stack(more_data_X, dim=0).to(device)
more_data_y = torch.Tensor(more_data_y).long().to(device)
return more_data_X, more_data_y

+ 12  - 1  xnas/core/builder.py

@@ -25,6 +25,7 @@ from xnas.runner.optimizer import optimizer_builder
from xnas.runner.criterion import criterion_builder
from xnas.runner.scheduler import lr_scheduler_builder


__all__ = [
'construct_loader',
'optimizer_builder',
@@ -48,9 +49,12 @@ from xnas.spaces.DrNAS.darts_cnn import _DrNAS_DARTS_CNN
from xnas.spaces.DrNAS.nb201_cnn import _DrNAS_nb201_CNN, _GDAS_nb201_CNN
from xnas.spaces.SPOS.cnn import _SPOS_CNN, _infer_SPOS_CNN
from xnas.spaces.DropNAS.cnn import _DropNASCNN
from xnas.spaces.ProxylessNAS.cnn import _MobileNetV2
from xnas.spaces.OFA.MobileNetV3.ofa_cnn import _OFAMobileNetV3
from xnas.spaces.OFA.ProxylessNet.ofa_cnn import _OFAProxylessNASNet
from xnas.spaces.OFA.ResNets.ofa_cnn import _OFAResNet
from xnas.spaces.BigNAS.cnn import _BigNAS_CNN, _infer_BigNAS_CNN
from xnas.spaces.AttentiveNAS.cnn import _AttentiveNAS_CNN, _infer_AttentiveNAS_CNN
from xnas.spaces.NASBenchMacro.cnn import _NBMacro_child_train, _NBMacro_sup_train

SUPPORTED_SPACES = {
@@ -63,15 +67,22 @@ SUPPORTED_SPACES = {
"gdas_nb201": _GDAS_nb201_CNN,
"dropnas": _DropNASCNN,
"spos": _SPOS_CNN,
"spos_nb201": _SPOS_nb201_CNN,
"nasbenchmacro": _NBMacro_sup_train,
"ofa_mbv3": _OFAMobileNetV3,
"ofa_proxyless": _OFAProxylessNASNet,
"ofa_resnet": _OFAResNet,
# models for inference
"attentivenas": _AttentiveNAS_CNN,
"bignas": _BigNAS_CNN,
# ===== models for inference =====
"infer_darts": _infer_DartsCNN,
"infer_nb201": _infer_NASBench201,
"infer_spos": _infer_SPOS_CNN,
"infer_attentivenas": _infer_AttentiveNAS_CNN,
# "infer_bignas": _infer_BigNAS_CNN,
"spos_nb201": _SPOS_nb201_CNN,
# "proxyless": _SuperProxylessNASNets,
"proxyless": _MobileNetV2,
}




+ 6  - 1  xnas/core/config.py

@@ -37,6 +37,9 @@ _C.LOADER.PIN_MEMORY = True
# _C.LOADER.BATCH_SIZE = [256, 128]
_C.LOADER.BATCH_SIZE = 256

# augmentation type, used by ImageNet only
# chosen from ['default', 'auto_augment_tf']
_C.LOADER.TRANSFORM = "default"


# ------------------------------------------------------------------------------------ #
@@ -150,7 +153,9 @@ _C.TEST = CfgNode(new_allowed=True)

_C.TEST.IM_SIZE = 224

_C.TEST.BATCH_SIZE = 128
# specific batch size for testing
# the search batch size (search.batch_size) is used if this value remains -1
_C.TEST.BATCH_SIZE = -1





+ 402  - 0  xnas/datasets/auto_augment_tf.py

@@ -0,0 +1,402 @@
""" Auto Augment
Implementation adapted from timm: https://github.com/rwightman/pytorch-image-models
"""

import random
import math
from PIL import Image, ImageOps, ImageEnhance
import PIL


_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])

_FILL = (128, 128, 128)

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.

_HPARAMS_DEFAULT = dict(
translate_const=250,
img_mean=_FILL,
)

_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC)


def _interpolation(kwargs):
interpolation = kwargs.pop('resample', Image.NEAREST)
if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation)
else:
return interpolation


def _check_args_tf(kwargs):
if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
kwargs.pop('fillcolor')
kwargs['resample'] = _interpolation(kwargs)


def shear_x(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)


def shear_y(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)


def translate_x_rel(img, pct, **kwargs):
pixels = pct * img.size[0]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_rel(img, pct, **kwargs):
pixels = pct * img.size[1]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def translate_x_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def rotate(img, degrees, **kwargs):
_check_args_tf(kwargs)
if _PIL_VER >= (5, 2):
return img.rotate(degrees, **kwargs)
elif _PIL_VER >= (5, 0):
w, h = img.size
post_trans = (0, 0)
rotn_center = (w / 2.0, h / 2.0)
angle = -math.radians(degrees)
matrix = [
round(math.cos(angle), 15),
round(math.sin(angle), 15),
0.0,
round(-math.sin(angle), 15),
round(math.cos(angle), 15),
0.0,
]

def transform(x, y, matrix):
(a, b, c, d, e, f) = matrix
return a * x + b * y + c, d * x + e * y + f

matrix[2], matrix[5] = transform(
-rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
)
matrix[2] += rotn_center[0]
matrix[5] += rotn_center[1]
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
else:
return img.rotate(degrees, resample=kwargs['resample'])


def auto_contrast(img, **__):
return ImageOps.autocontrast(img)


def invert(img, **__):
return ImageOps.invert(img)


def equalize(img, **__):
return ImageOps.equalize(img)


def solarize(img, thresh, **__):
return ImageOps.solarize(img, thresh)


def solarize_add(img, add, thresh=128, **__):
lut = []
for i in range(256):
if i < thresh:
lut.append(min(255, i + add))
else:
lut.append(i)
if img.mode in ("L", "RGB"):
if img.mode == "RGB" and len(lut) == 256:
lut = lut + lut + lut
return img.point(lut)
else:
return img


def posterize(img, bits_to_keep, **__):
if bits_to_keep >= 8:
return img
bits_to_keep = max(1, bits_to_keep) # prevent all 0 images
return ImageOps.posterize(img, bits_to_keep)


def contrast(img, factor, **__):
return ImageEnhance.Contrast(img).enhance(factor)


def color(img, factor, **__):
return ImageEnhance.Color(img).enhance(factor)


def brightness(img, factor, **__):
return ImageEnhance.Brightness(img).enhance(factor)


def sharpness(img, factor, **__):
return ImageEnhance.Sharpness(img).enhance(factor)


def _randomly_negate(v):
"""With 50% prob, negate the value"""
return -v if random.random() > 0.5 else v


def _rotate_level_to_arg(level):
# range [-30, 30]
level = (level / _MAX_LEVEL) * 30.
level = _randomly_negate(level)
return (level,)


def _enhance_level_to_arg(level):
# range [0.1, 1.9]
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)


def _shear_level_to_arg(level):
# range [-0.3, 0.3]
level = (level / _MAX_LEVEL) * 0.3
level = _randomly_negate(level)
return (level,)


def _translate_abs_level_to_arg(level, translate_const):
level = (level / _MAX_LEVEL) * float(translate_const)
level = _randomly_negate(level)
return (level,)

def _translate_abs_level_to_arg2(level):
level = (level / _MAX_LEVEL) * float(_HPARAMS_DEFAULT['translate_const'])
level = _randomly_negate(level)
return (level,)

def _translate_rel_level_to_arg(level):
# range [-0.45, 0.45]
level = (level / _MAX_LEVEL) * 0.45
level = _randomly_negate(level)
return (level,)


# def level_to_arg(hparams):
# return {
# 'AutoContrast': lambda level: (),
# 'Equalize': lambda level: (),
# 'Invert': lambda level: (),
# 'Rotate': _rotate_level_to_arg,
# # FIXME these are both different from original impl as I believe there is a bug,
# # not sure what is the correct alternative, hence 2 options that look better
# 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8]
# 'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0]
# 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256]
# 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110]
# 'Color': _enhance_level_to_arg,
# 'Contrast': _enhance_level_to_arg,
# 'Brightness': _enhance_level_to_arg,
# 'Sharpness': _enhance_level_to_arg,
# 'ShearX': _shear_level_to_arg,
# 'ShearY': _shear_level_to_arg,
# 'TranslateX': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']),
# 'TranslateY': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']),
# 'TranslateXRel': lambda level: _translate_rel_level_to_arg(level),
# 'TranslateYRel': lambda level: _translate_rel_level_to_arg(level),
# }


NAME_TO_OP = {
'AutoContrast': auto_contrast,
'Equalize': equalize,
'Invert': invert,
'Rotate': rotate,
'Posterize': posterize,
'Posterize2': posterize,
'Solarize': solarize,
'SolarizeAdd': solarize_add,
'Color': color,
'Contrast': contrast,
'Brightness': brightness,
'Sharpness': sharpness,
'ShearX': shear_x,
'ShearY': shear_y,
'TranslateX': translate_x_abs,
'TranslateY': translate_y_abs,
'TranslateXRel': translate_x_rel,
'TranslateYRel': translate_y_rel,
}


def pass_fn(input):
return ()


def _conversion0(input):
return (int((input / _MAX_LEVEL) * 4) + 4,)


def _conversion1(input):
return (4 - int((input / _MAX_LEVEL) * 4),)


def _conversion2(input):
return (int((input / _MAX_LEVEL) * 256),)


def _conversion3(input):
return (int((input / _MAX_LEVEL) * 110),)


class AutoAugmentOp:
def __init__(self, name, prob, magnitude, hparams={}):
self.aug_fn = NAME_TO_OP[name]
# self.level_fn = level_to_arg(hparams)[name]
if name == 'AutoContrast' or name == 'Equalize' or name == 'Invert':
self.level_fn = pass_fn
elif name == 'Rotate':
self.level_fn = _rotate_level_to_arg
elif name == 'Posterize':
self.level_fn = _conversion0
elif name == 'Posterize2':
self.level_fn = _conversion1
elif name == 'Solarize':
self.level_fn = _conversion2
elif name == 'SolarizeAdd':
self.level_fn = _conversion3
elif name == 'Color' or name == 'Contrast' or name == 'Brightness' or name == 'Sharpness':
self.level_fn = _enhance_level_to_arg
elif name == 'ShearX' or name == 'ShearY':
self.level_fn = _shear_level_to_arg
elif name == 'TranslateX' or name == 'TranslateY':
self.level_fn = _translate_abs_level_to_arg2
elif name == 'TranslateXRel' or name == 'TranslateYRel':
self.level_fn = _translate_rel_level_to_arg
else:
print("{} not recognized".format({}))
self.prob = prob
self.magnitude = magnitude
# If std deviation of magnitude is > 0, we introduce some randomness
# in the usually fixed policy and sample magnitude from normal dist
# with mean magnitude and std-dev of magnitude_std.
# NOTE This is being tested as it's not in paper or reference impl.
self.magnitude_std = 0.5 # FIXME add arg/hparam
self.kwargs = {
'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL,
'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION
}

def __call__(self, img):
if self.prob < random.random():
return img
magnitude = self.magnitude
if self.magnitude_std and self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std)
magnitude = min(_MAX_LEVEL, max(0, magnitude))
level_args = self.level_fn(magnitude)
return self.aug_fn(img, *level_args, **self.kwargs)


def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
# ImageNet policy from TPU EfficientNet impl, cannot find
# a paper reference.
policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy]
return pc


def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
# ImageNet policy from https://arxiv.org/abs/1805.09501
policy = [
[('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
[('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
[('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
[('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
]
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy]
return pc


def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT):
if name == 'original':
return auto_augment_policy_original(hparams)
elif name == 'v0':
return auto_augment_policy_v0(hparams)
else:
print("Unknown auto_augmentation policy {}".format(name))
raise AssertionError()


class AutoAugment:

def __init__(self, policy):
self.policy = policy

def __call__(self, img):
sub_policy = random.choice(self.policy)
for op in sub_policy:
img = op(img)
return img
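A hedged usage sketch of the policy classes above; the image path is a placeholder:

from PIL import Image
img = Image.open("example.jpg").convert("RGB")
aa = AutoAugment(auto_augment_policy('v0'))   # or 'original'
img_aug = aa(img)                             # applies one randomly chosen sub-policy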

+ 64
- 136
xnas/datasets/imagenet.py View File

@@ -1,99 +1,67 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""ImageNet dataset."""

import math
import os
import re

import numpy as np
import torch
import torch.utils.data
import torchvision.transforms as torch_transforms
from PIL import Image
from torch.utils.data.distributed import DistributedSampler

import xnas.logger.logging as logging
from xnas.core.config import cfg
from xnas.datasets.transforms import MultiSizeRandomCrop
from xnas.datasets.transforms_imagenet import get_data_transform


logger = logging.get_logger(__name__)


class ImageFolder():
def __init__(
self,
datapath,
split,
batch_size=None,
dataset_name='imagenet',
_rgb_normalized_mean=None,
_rgb_normalized_std=None,
transforms=None,
num_workers=None,
pin_memory=None,
shuffle=True
):
"""New ImageFolder
Currently supports ImageNet only.
"""
def __init__(self, datapath, batch_size, split=None, use_val=False, augment_type='default', **kwargs):
datapath = './data/imagenet/' if not datapath else datapath
assert os.path.exists(datapath), "Data path '{}' not found".format(datapath)
self.use_val = cfg.LOADER.USE_VAL
self._data_path, self._split, self.dataset_name = datapath, split, dataset_name
self._rgb_normalized_mean, self._rgb_normalized_std = _rgb_normalized_mean, _rgb_normalized_std
self.num_workers = cfg.LOADER.NUM_WORKERS if num_workers is None else num_workers
self.pin_memory = cfg.LOADER.PIN_MEMORY if pin_memory is None else pin_memory
self.shuffle = shuffle
self.use_val = use_val
self.data_path = datapath
self.split = split
self.batch_size = batch_size
if not self.use_val:
assert sum(self.split) == 1, "Summation of split should be 1"
self.msrc = None
self.loader = torch.utils.data.DataLoader
# self.collate_fn = None
self.num_workers = cfg.LOADER.NUM_WORKERS
self.pin_memory = cfg.LOADER.PIN_MEMORY
self.augment_type = augment_type
self.kwargs = kwargs
if transforms is None:
im_size = cfg.SEARCH.IM_SIZE if len(cfg.SEARCH.MULTI_SIZES)==0 else cfg.SEARCH.MULTI_SIZES
self.transforms = [{'crop': 'random', 'crop_size': im_size, 'min_crop': 0.08, 'random_flip': True},
{'crop': 'center', 'crop_size': cfg.TEST.IM_SIZE, 'min_crop': -1, 'random_flip': False}] # NOTE: min_crop is not used here.
else:
self.transforms = transforms
if not self.use_val:
assert len(self.transforms) == len(self._split), "Length of split and transforms should be equal"
else:
assert len(self.transforms) == 2
assert sum(self.split) == 1, "Summation of split should be 1."
# Check if using multisize_random_crop
if len(cfg.SEARCH.MULTI_SIZES):
# setting default loader if not using MultiSizeRandomCrop
if len(cfg.SEARCH.MULTI_SIZES) == 0:
self.loader = torch.utils.data.DataLoader
else:
from xnas.datasets.utils.msrc_loader import msrc_DataLoader
self.msrc = MultiSizeRandomCrop(cfg.SEARCH.MULTI_SIZES)
self.loader = msrc_DataLoader
logger.info("Using Random MultiSize Crop, continuous={} candidate im_sizes={}".format(self.msrc.CONTINUOUS, self.msrc.CANDIDATE_SIZES))
logger.info("Using MultiSize RandomCrop, continuous={} candidate im_sizes={}".format(self.msrc.CONTINUOUS, self.msrc.CANDIDATE_SIZES))

# Read all dataset
# Acquiring transforms
logger.info("Constructing transforms")
self.train_transform, self.test_transform = self._build_transforms()
# Read all datasets
logger.info("Constructing ImageFolder")
self._construct_imdb()

def _construct_imdb(self):
# Images are stored per class in subdirs (format: n<number>)
if not self.use_val:
split_files = os.listdir(self._data_path)
else:
split_files = os.listdir(os.path.join(self._data_path, "train"))
if self.dataset_name == "imagenet":
# imagenet format folder names
self._class_ids = sorted(
f for f in split_files if re.match(r"^n[0-9]+$", f))
self._rgb_normalized_mean = [0.485, 0.456, 0.406]
self._rgb_normalized_std = [0.229, 0.224, 0.225]
elif self.dataset_name == 'custom':
self._class_ids = sorted(
f for f in split_files if not f[0] == '.')
split_files = os.listdir(self.data_path)
else:
raise NotImplementedError
split_files = os.listdir(os.path.join(self.data_path, "train"))
# imagenet format folder names
self._class_ids = sorted(
f for f in split_files if re.match(r"^n[0-9]+$", f))

# Map class ids to contiguous ids
self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)}
@@ -102,7 +70,7 @@ class ImageFolder():
self._imdb = []
for class_id in self._class_ids:
cont_id = self._class_id_cont_id[class_id]
train_im_dir = os.path.join(self._data_path, class_id)
train_im_dir = os.path.join(self.data_path, class_id)
for im_name in os.listdir(train_im_dir):
im_path = os.path.join(train_im_dir, im_name)
if is_image_file(im_path):
@@ -112,8 +80,8 @@ class ImageFolder():
else:
self._train_imdb = []
self._val_imdb = []
train_path = os.path.join(self._data_path, "train")
val_path = os.path.join(self._data_path, "val")
train_path = os.path.join(self.data_path, "train")
val_path = os.path.join(self.data_path, "val")
for class_id in self._class_ids:
cont_id = self._class_id_cont_id[class_id]
train_im_dir = os.path.join(train_path, class_id)
@@ -129,7 +97,11 @@ class ImageFolder():
logger.info("Number of classes: {}".format(len(self._class_ids)))
logger.info("Number of TRAIN images: {}".format(len(self._train_imdb)))
logger.info("Number of VAL images: {}".format(len(self._val_imdb)))

def _build_transforms(self):
# KWARGS for 'auto_augment_tf': policy='v0', interpolation='bilinear'
return get_data_transform(augment=self.augment_type, **self.kwargs)
def generate_data_loader(self):
if not self.use_val:
indices = list(range(len(self._imdb)))
@@ -138,21 +110,22 @@ class ImageFolder():
data_loaders = []
pre_partition = 0.
pre_index = 0
for i, _split in enumerate(self._split):
for i, _split in enumerate(self.split):
_current_partition = pre_partition + _split
_current_index = int(len(self._imdb) * _current_partition)
_current_indices = indices[pre_index: _current_index]
assert not len(_current_indices) == 0, "The length of indices is zero!"
dataset = ImageList_torch([self._imdb[j] for j in _current_indices],
self.msrc, # add support for multisize_random_crop
_rgb_normalized_mean=self._rgb_normalized_mean,
_rgb_normalized_std=self._rgb_normalized_std,
**self.transforms[i])
dataset = ImageList_torch(
[self._imdb[j] for j in _current_indices],
# the first split uses the training transform; the remaining splits use the test transform
transform=self.train_transform if i==0 else self.test_transform
)
sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
loader = self.loader(dataset,
batch_size=self.batch_size[i],
shuffle=(False if sampler else True),
sampler=sampler,
drop_last=(True if i==0 else False),
num_workers=self.num_workers,
pin_memory=self.pin_memory)
data_loaders.append(loader)
@@ -160,82 +133,37 @@ class ImageFolder():
pre_index = _current_index
return data_loaders
else:
train_dataset = ImageList_torch(
self._train_imdb,
self.msrc,
_rgb_normalized_mean=self._rgb_normalized_mean,
_rgb_normalized_std=self._rgb_normalized_std,
**self.transforms[0]
)
sampler = DistributedSampler(train_dataset) if cfg.NUM_GPUS > 1 else None
train_dataset = ImageList_torch(self._train_imdb, self.train_transform)
train_sampler = DistributedSampler(train_dataset) if cfg.NUM_GPUS > 1 else None
train_loader = self.loader(train_dataset,
batch_size=self.batch_size[0],
shuffle=(False if sampler else True),
sampler=sampler,
num_workers=self.num_workers,
pin_memory=self.pin_memory)
val_dataset = ImageList_torch(
self._val_imdb,
self.msrc,
_rgb_normalized_mean=self._rgb_normalized_mean,
_rgb_normalized_std=self._rgb_normalized_std,
**self.transforms[1]
)
sampler = DistributedSampler(val_dataset) if cfg.NUM_GPUS > 1 else None
batch_size=self.batch_size[0],
shuffle=(False if train_sampler else True),
sampler=train_sampler,
drop_last=True,
num_workers=self.num_workers,
pin_memory=self.pin_memory)
val_dataset = ImageList_torch(self._val_imdb, self.test_transform)
val_sampler = DistributedSampler(val_dataset) if cfg.NUM_GPUS > 1 else None
valid_loader = self.loader(val_dataset,
batch_size=self.batch_size[1],
shuffle=(False if sampler else True),
sampler=sampler,
num_workers=self.num_workers,
pin_memory=self.pin_memory)
batch_size=self.batch_size[1],
shuffle=(False if val_sampler else True),
sampler=val_sampler,
drop_last=False,
num_workers=self.num_workers,
pin_memory=self.pin_memory)
return [train_loader, valid_loader]
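# A hedged usage sketch of the rewritten ImageFolder above; the data path and
# batch sizes are placeholders, and the auto_augment_tf kwargs are optional.
folder = ImageFolder("./data/imagenet/", batch_size=[256, 128], use_val=True, augment_type="auto_augment_tf")
train_loader, valid_loader = folder.generate_data_loader()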


class ImageList_torch(torch.utils.data.Dataset):
'''
ImageList dataloader with torch backends
From https://github.com/pytorch/vision/issues/81
'''

def __init__(
self,
_list,
msrc=None,
_rgb_normalized_mean=None,
_rgb_normalized_std=None,
crop='random',
crop_size=224,
min_crop=0.08,
random_flip=False):
self._imdb = _list
self.msrc = msrc
self._bgr_normalized_mean = _rgb_normalized_mean[::-1]
self._bgr_normalized_std = _rgb_normalized_std[::-1]
self.crop = crop
self.crop_size = crop_size
self.min_crop = min_crop
self.random_flip = random_flip
def __init__(self, img_list, transform):
self._imdb = img_list
self.transform = transform
self.loader = pil_loader
self._construct_transforms()

def _construct_transforms(self):
transforms = []
if self.crop == "random":
if isinstance(self.crop_size, int):
transforms.append(torch_transforms.RandomResizedCrop(self.crop_size, scale=(self.min_crop, 1.0)))
elif isinstance(self.crop_size, list):
# using MultiSizeRandomCrop
transforms.append(self.msrc)
elif self.crop == "center":
transforms.append(torch_transforms.Resize(math.ceil(self.crop_size / 0.875))) # assert crop_size==224
transforms.append(torch_transforms.CenterCrop(self.crop_size))
# TODO: color augmentation support
# transforms.append(torch_transforms.ColorJitter(brightness=0.4, contrast=0.4,saturation=0.4, hue=0.2))
if self.random_flip:
transforms.append(torch_transforms.RandomHorizontalFlip())
transforms.append(torch_transforms.ToTensor())
transforms.append(torch_transforms.Normalize(mean=self._bgr_normalized_mean, std=self._bgr_normalized_std))
self.transform = torch_transforms.Compose(transforms)

def __getitem__(self, index):
impath = self._imdb[index]["im_path"]


+ 42
- 30
xnas/datasets/loader.py View File

@@ -26,21 +26,22 @@ def construct_loader(
cutout_length=0,
use_classes=None,
transforms=None,
**kwargs
):
"""Construct NAS dataloaders with train&valid subsets."""
split = cfg.LOADER.SPLIT
name = cfg.LOADER.DATASET
batch_size = cfg.LOADER.BATCH_SIZE
batch_size = cfg.LOADER.BATCH_SIZE
datapath = cfg.LOADER.DATAPATH
assert (name in SUPPORTED_DATASETS) or (name in IMAGEFOLDER_FORMAT), "dataset not supported."

# expand batch_size to support different number during training & validating
if isinstance(batch_size, int):
batch_size = [batch_size, batch_size]
batch_size = [batch_size] * len(split)
elif batch_size is None:
batch_size = [256, 256]
batch_size = [256] * len(split)
assert len(batch_size) == len(split), "lengths of batch_size and split should be same."
# check if randomresized crop is used only in ImageFolder type datasets
@@ -52,9 +53,11 @@ def construct_loader(
train_data, _ = get_data(name, datapath, cutout_length, use_classes=use_classes, transforms=transforms)
return split_dataloader(train_data, batch_size, split)
elif name in IMAGEFOLDER_FORMAT:
assert cfg.LOADER.USE_VAL is False, "construct_loader does not use the VAL dataset."
aug_type = cfg.LOADER.TRANSFORM
return ImageFolder( # using path of training data of ImageNet as `datapath`
datapath, split, batch_size=batch_size,
transforms=transforms,
datapath, batch_size=batch_size, split=split,
use_val=False, augment_type=aug_type, **kwargs
).generate_data_loader()
else:
print("dataset not supported.")
@@ -137,39 +140,48 @@ def get_normal_dataloader(
name=None,
train_batch=None,
cutout_length=0,
download=True,
use_classes=None,
transforms=None,
**kwargs
):
name=cfg.LOADER.DATASET if name is None else name
train_batch=cfg.LOADER.BATCH_SIZE if train_batch is None else train_batch
name=cfg.LOADER.DATASET
root=cfg.LOADER.DATAPATH
test_batch=cfg.TEST.BATCH_SIZE
datapath=cfg.LOADER.DATAPATH
test_batch=cfg.LOADER.BATCH_SIZE if cfg.TEST.BATCH_SIZE == -1 else cfg.TEST.BATCH_SIZE
# get normal dataloaders with train&test subsets.
train_data, test_data = get_data(name, root, cutout_length, download, use_classes, transforms)
# if loader.batch_size is a list for [train, val_1, ...], the first value will be used.
if isinstance(train_batch, list):
train_batch = train_batch[0]
train_loader = data.DataLoader(
dataset=train_data,
batch_size=train_batch,
shuffle=True,
num_workers=cfg.LOADER.NUM_WORKERS,
pin_memory=cfg.LOADER.PIN_MEMORY,
)
test_loader = data.DataLoader(
dataset=test_data,
batch_size=test_batch,
shuffle=False,
num_workers=cfg.LOADER.NUM_WORKERS,
pin_memory=cfg.LOADER.PIN_MEMORY,
)
return train_loader, test_loader
assert (name in SUPPORTED_DATASETS) or (name in IMAGEFOLDER_FORMAT), "dataset not supported."
assert isinstance(train_batch, int), "normal dataloader using single training batch-size, not list."
# check if randomresized crop is used only in ImageFolder type datasets
if len(cfg.SEARCH.MULTI_SIZES):
assert name in IMAGEFOLDER_FORMAT, "RandomResizedCrop can only be used in ImageFolder currently."

if name in SUPPORTED_DATASETS:
# get normal dataloaders with train&test subsets.
train_data, test_data = get_data(name, datapath, cutout_length, use_classes=use_classes, transforms=transforms)
train_loader = data.DataLoader(
dataset=train_data,
batch_size=train_batch,
shuffle=True,
num_workers=cfg.LOADER.NUM_WORKERS,
pin_memory=cfg.LOADER.PIN_MEMORY,
)
test_loader = data.DataLoader(
dataset=test_data,
batch_size=test_batch,
shuffle=False,
num_workers=cfg.LOADER.NUM_WORKERS,
pin_memory=cfg.LOADER.PIN_MEMORY,
)
return train_loader, test_loader
elif name in IMAGEFOLDER_FORMAT:
assert cfg.LOADER.USE_VAL is True, "normal dataloader requires LOADER.USE_VAL=True."
aug_type = cfg.LOADER.TRANSFORM
return ImageFolder( # using path of training data of ImageNet as `datapath`
datapath, batch_size=[train_batch, test_batch],
use_val=True, augment_type=aug_type, **kwargs
).generate_data_loader()
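# A hedged sketch of the two entry points after this change; the cfg values
# (e.g. LOADER.SPLIT=[0.8, 0.2]) are assumptions for illustration only.
# search phase with LOADER.USE_VAL=False:
train_loader, valid_loader = construct_loader()
# retraining/evaluation phase with LOADER.USE_VAL=True:
train_loader, test_loader = get_normal_dataloader()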

def split_dataloader(data_, batch_size, split):
assert 0 not in split, "illegal split list with zero."


+ 124
- 0
xnas/datasets/transforms_imagenet.py View File

@@ -0,0 +1,124 @@
import math
import torch
from PIL import Image
import torchvision.transforms as transforms

from xnas.core.config import cfg
from xnas.datasets.auto_augment_tf import auto_augment_policy, AutoAugment


IMAGENET_RGB_MEAN = [0.485, 0.456, 0.406]
IMAGENET_RGB_STD = [0.229, 0.224, 0.225]


def get_data_transform(augment, **kwargs):
if len(cfg.SEARCH.MULTI_SIZES)==0:
# using single image_size for training
train_crop_size = cfg.SEARCH.IM_SIZE
else:
# using MultiSize_RandomCrop
train_crop_size = cfg.SEARCH.MULTI_SIZES
min_train_scale = 0.08
test_scale = math.ceil(cfg.TEST.IM_SIZE / 0.875) # 224 / 0.875 = 256
test_crop_size = cfg.TEST.IM_SIZE  # center-crop size, 224 by default

interpolation = transforms.InterpolationMode.BICUBIC
if 'interpolation' in kwargs.keys() and kwargs['interpolation'] == 'bilinear':
interpolation = transforms.InterpolationMode.BILINEAR
da_args = {
'train_crop_size': train_crop_size,
'train_min_scale': min_train_scale,
'test_scale': test_scale,
'test_crop_size': test_crop_size,
'interpolation': interpolation,
}

if augment == 'default':
return build_default_transform(**da_args)
elif augment == 'auto_augment_tf':
policy = 'v0' if 'policy' not in kwargs.keys() else kwargs['policy']
return build_imagenet_auto_augment_tf_transform(policy=policy, **da_args)
else:
raise ValueError(augment)


def get_normalize():
return transforms.Normalize(
mean=torch.Tensor(IMAGENET_RGB_MEAN),
std=torch.Tensor(IMAGENET_RGB_STD),
)


def get_randomResizedCrop(train_crop_size=224, train_min_scale=0.08, interpolation=transforms.InterpolationMode.BICUBIC):
if isinstance(train_crop_size, int):
return transforms.RandomResizedCrop(train_crop_size, scale=(train_min_scale, 1.0), interpolation=interpolation)
elif isinstance(train_crop_size, list):
from xnas.datasets.transforms import MultiSizeRandomCrop
msrc = MultiSizeRandomCrop(train_crop_size)
return msrc
else:
raise TypeError(train_crop_size)


def build_default_transform(
train_crop_size=224, train_min_scale=0.08, test_scale=256, test_crop_size=224, interpolation=transforms.InterpolationMode.BICUBIC
):
normalize = get_normalize()
train_crop_transform = get_randomResizedCrop(
train_crop_size, train_min_scale, interpolation
)
train_transform = transforms.Compose(
[
# transforms.RandomResizedCrop(train_crop_size, interpolation=interpolation),
train_crop_transform,
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
)
test_transform = transforms.Compose(
[
transforms.Resize(test_scale, interpolation=interpolation),
transforms.CenterCrop(test_crop_size),
transforms.ToTensor(),
normalize,
]
)
return train_transform, test_transform


def build_imagenet_auto_augment_tf_transform(
policy='v0', train_crop_size=224, train_min_scale=0.08, test_scale=256, test_crop_size=224, interpolation=transforms.InterpolationMode.BICUBIC
):

normalize = get_normalize()
img_size = train_crop_size
aa_params = {
"translate_const": int(img_size * 0.45),
"img_mean": tuple(round(x) for x in IMAGENET_RGB_MEAN),
}

aa_policy = AutoAugment(auto_augment_policy(policy, aa_params))
train_crop_transform = get_randomResizedCrop(
train_crop_size, train_min_scale, interpolation
)
train_transform = transforms.Compose(
[
# transforms.RandomResizedCrop(train_crop_size, interpolation=interpolation),
train_crop_transform,
transforms.RandomHorizontalFlip(),
aa_policy,
transforms.ToTensor(),
normalize,
]
)
test_transform = transforms.Compose(
[
transforms.Resize(test_scale, interpolation=interpolation),
transforms.CenterCrop(test_crop_size),
transforms.ToTensor(),
normalize,
]
)
return train_transform, test_transform
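A hedged usage sketch of the entry point above; the kwargs follow the comment in imagenet.py, and pil_image is an assumed PIL input:

train_tf, test_tf = get_data_transform('auto_augment_tf', policy='v0', interpolation='bilinear')
tensor = train_tf(pil_image)  # returns a normalized CHW tensor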

+ 36
- 4
xnas/runner/criterion.py View File

@@ -32,6 +32,35 @@ def CrossEntropyLoss_label_smoothed(pred, target, label_smoothing=0.):
return CrossEntropyLoss_soft_target(pred, soft_target)


class KLLossSoft(torch.nn.modules.loss._Loss):
""" inplace distillation for image classification
output: output logits of the student network
target: output logits of the teacher network
T: temperature
KL(p||q) = Ep \log p - \Ep log q
"""
def forward(self, output, soft_logits, target=None, temperature=1., alpha=0.9):
output, soft_logits = output / temperature, soft_logits / temperature
soft_target_prob = F.softmax(soft_logits, dim=1)
output_log_prob = F.log_softmax(output, dim=1)
kd_loss = -torch.sum(soft_target_prob * output_log_prob, dim=1)
if target is not None:
n_class = output.size(1)
target = torch.zeros_like(output).scatter(1, target.view(-1, 1), 1)
target = target.unsqueeze(1)
output_log_prob = output_log_prob.unsqueeze(2)
ce_loss = -torch.bmm(target, output_log_prob).squeeze()
loss = alpha * temperature * temperature * kd_loss + (1.0 - alpha) * ce_loss
else:
loss = kd_loss
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
return loss


class MultiHeadCrossEntropyLoss(nn.Module):
def forward(self, preds, targets):
assert preds.dim() == 3, preds
@@ -50,12 +79,15 @@ class MultiHeadCrossEntropyLoss(nn.Module):

SUPPORTED_CRITERIONS = {
"cross_entropy": torch.nn.CrossEntropyLoss(),
"cross_entropy_soft": CrossEntropyLoss_soft_target,
"cross_entropy_smooth": CrossEntropyLoss_label_smoothed,
"cross_entropy_multihead": MultiHeadCrossEntropyLoss()
"cross_entropy_multihead": MultiHeadCrossEntropyLoss(),
"kl_soft": KLLossSoft(),
}


def criterion_builder():
def criterion_builder(criterion=None):
err_str = "Loss function type '{}' not supported"
assert cfg.SEARCH.LOSS_FUN in SUPPORTED_CRITERIONS.keys(), err_str.format(cfg.SEARCH.LOSS_FUN)
return SUPPORTED_CRITERIONS[cfg.SEARCH.LOSS_FUN]
loss_fun = cfg.SEARCH.LOSS_FUN if criterion is None else criterion
assert loss_fun in SUPPORTED_CRITERIONS.keys(), err_str.format(loss_fun)
return SUPPORTED_CRITERIONS[loss_fun]
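A hedged sketch of inplace distillation with the new kl_soft loss; the tensors are placeholders:

import torch
criterion = criterion_builder("kl_soft")
student_logits, teacher_logits = torch.randn(8, 1000), torch.randn(8, 1000)
labels = torch.randint(0, 1000, (8,))
loss = criterion(student_logits, teacher_logits, target=labels, temperature=1.0, alpha=0.9)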

+ 0
- 1
xnas/runner/optimizer.py View File

@@ -1,7 +1,6 @@
"""Optimizers."""

import torch
import torch.nn as nn
from xnas.core.config import cfg




+ 4
- 2
xnas/runner/scheduler.py View File

@@ -85,14 +85,15 @@ class GradualWarmupScheduler(_LRScheduler):
def _calc_learning_rate(
init_lr, n_epochs, epoch, n_iter=None, iter=0,
):
if cfg.SEARCH.LOSS_FUN.startswith("cross_entropy"):
if cfg.OPTIM.LR_POLICY == "cos":
t_total = n_epochs * n_iter
t_cur = epoch * n_iter + iter
lr = 0.5 * init_lr * (1 + math.cos(math.pi * t_cur / t_total))
else:
raise ValueError("do not support: {}".format(cfg.SEARCH.LOSS_FUN))
raise ValueError("do not support: {}".format(cfg.OPTIM.LR_POLICY))
return lr
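# Worked example of the cosine decay above (illustrative values):
# with init_lr=0.1, n_epochs=100, n_iter=500, at epoch=50, iter=0
# t_cur / t_total = 25000 / 50000, so lr = 0.5 * 0.1 * (1 + cos(pi/2)) ≈ 0.05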


def _warmup_adjust_learning_rate(
init_lr, n_epochs, epoch, n_iter, iter=0, warmup_lr=0
):
@@ -102,6 +103,7 @@ def _warmup_adjust_learning_rate(
new_lr = T_cur / t_total * (init_lr - warmup_lr) + warmup_lr
return new_lr


def adjust_learning_rate_per_batch(epoch, n_iter=None, iter=0, warmup=False):
"""adjust learning of a given optimizer and return the new learning rate"""


+ 7
- 7
xnas/runner/trainer.py View File

@@ -114,9 +114,9 @@ class Trainer(Recorder):
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_median()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
top1_err = self.test_meter.mb_top1_err.get_win_avg()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch)
self.test_meter.reset()
@@ -349,9 +349,9 @@ class OneShotTrainer(Trainer):
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_median()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
top1_err = self.test_meter.mb_top1_err.get_win_avg()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch)
self.test_meter.reset()
@@ -372,7 +372,7 @@ class OneShotTrainer(Trainer):
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item()
self.evaluate_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
top1_err = self.evaluate_meter.mb_top1_err.get_win_median()
top1_err = self.evaluate_meter.mb_top1_err.get_win_avg()
# self.evaluate_sampler.record(choice, top1_err)
self.evaluate_meter.reset()
return top1_err


+ 8
- 9
xnas/runner/trainer_spos.py View File

@@ -114,9 +114,9 @@ class Trainer(Recorder):
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_median()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
top1_err = self.test_meter.mb_top1_err.get_win_avg()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch)
self.test_meter.reset()
@@ -381,10 +381,9 @@ class OneShotTrainer(Trainer):
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_median()
top1_err_avg = self.test_meter.mb_top1_err.get_global_avg()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
top1_err = self.test_meter.mb_top1_err.get_win_avg()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch)
self.test_meter.reset()
@@ -392,7 +391,7 @@ class OneShotTrainer(Trainer):
if self.best_err > top1_err:
self.best_err = top1_err
self.saving(cur_epoch, best=True)
return top1_err_avg
return top1_err
@torch.no_grad()
def evaluate_epoch(self, sample):
@@ -405,7 +404,7 @@ class OneShotTrainer(Trainer):
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item()
self.evaluate_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
top1_err = self.evaluate_meter.mb_top1_err.get_win_median()
top1_err = self.evaluate_meter.mb_top1_err.get_win_avg()
# self.evaluate_sampler.record(choice, top1_err)
self.evaluate_meter.reset()
return top1_err


+ 652
- 0
xnas/spaces/AttentiveNAS/cnn.py View File

@@ -0,0 +1,652 @@
# Implementation adapted from AttentiveNAS: https://github.com/facebookresearch/AttentiveNAS

import random
from copy import deepcopy
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from xnas.spaces.OFA.ops import ResidualBlock
from xnas.spaces.OFA.dynamic_ops import DynamicLinearLayer
from xnas.spaces.OFA.utils import val2list, make_divisible
from xnas.spaces.BigNAS.dynamic_layers import DynamicMBConvLayer, DynamicConvLayer, DynamicShortcutLayer


class AttentiveNasStaticModel(nn.Module):

def __init__(self, first_conv, blocks, last_conv, classifier, resolution, use_v3_head=True):
super(AttentiveNasStaticModel, self).__init__()
self.first_conv = first_conv
self.blocks = nn.ModuleList(blocks)
self.last_conv = last_conv
self.classifier = classifier

self.resolution = resolution #input size
self.use_v3_head = use_v3_head

def forward(self, x):
# resize input to target resolution first
# Rule: transform images into different sizes
if x.size(-1) != self.resolution:
x = F.interpolate(x, size=self.resolution, mode='bicubic')

x = self.first_conv(x)
for block in self.blocks:
x = block(x)
x = self.last_conv(x)
if not self.use_v3_head:
x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
x = torch.squeeze(x)
x = self.classifier(x)
return x


@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
for block in self.blocks:
_str += block.module_str + '\n'
#_str += self.last_conv.module_str + '\n'
_str += self.classifier.module_str
return _str

@property
def config(self):
return {
'name': AttentiveNasStaticModel.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
#'last_conv': self.last_conv.config,
'classifier': self.classifier.config,
'resolution': self.resolution
}


def weight_initialization(self):
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)

@staticmethod
def build_from_config(config):
raise NotImplementedError

def set_bn_param(self, momentum, eps):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
if momentum is not None:
m.momentum = float(momentum)
else:
m.momentum = None
m.eps = float(eps)
return

def get_bn_param(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
return {
'momentum': m.momentum,
'eps': m.eps,
}
return None

def reset_running_stats_for_calibration(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
m.training = True
m.momentum = None # cumulative moving average
m.reset_running_stats()


class AttentiveNasDynamicModel(nn.Module):

def __init__(self, supernet_cfg, n_classes=1000, bn_param=(0., 1e-5)):
super(AttentiveNasDynamicModel, self).__init__()

self.supernet_cfg = supernet_cfg
self.n_classes = n_classes
self.use_v3_head = getattr(self.supernet_cfg, 'use_v3_head', False)
self.stage_names = ['first_conv', 'mb1', 'mb2', 'mb3', 'mb4', 'mb5', 'mb6', 'mb7', 'last_conv']

self.width_list, self.depth_list, self.ks_list, self.expand_ratio_list = [], [], [], []
for name in self.stage_names:
block_cfg = getattr(self.supernet_cfg, name)
self.width_list.append(block_cfg.c)
if name.startswith('mb'):
self.depth_list.append(block_cfg.d)
self.ks_list.append(block_cfg.k)
self.expand_ratio_list.append(block_cfg.t)
self.resolution_list = self.supernet_cfg.resolutions

self.cfg_candidates = {
'resolution': self.resolution_list,
'width': self.width_list,
'depth': self.depth_list,
'kernel_size': self.ks_list,
'expand_ratio': self.expand_ratio_list
}

#first conv layer, including conv, bn, act
out_channel_list, act_func, stride = \
self.supernet_cfg.first_conv.c, self.supernet_cfg.first_conv.act_func, self.supernet_cfg.first_conv.s
self.first_conv = DynamicConvLayer(
in_channel_list=val2list(3), out_channel_list=out_channel_list,
kernel_size=3, stride=stride, act_func=act_func,
)

# inverted residual blocks
self.block_group_info = []
blocks = []
_block_index = 0
feature_dim = out_channel_list
for stage_id, key in enumerate(self.stage_names[1:-1]):
block_cfg = getattr(self.supernet_cfg, key)
width = block_cfg.c
n_block = max(block_cfg.d)
act_func = block_cfg.act_func
ks = block_cfg.k
expand_ratio_list = block_cfg.t
use_se = block_cfg.se

self.block_group_info.append([_block_index + i for i in range(n_block)])
_block_index += n_block

output_channel = width
for i in range(n_block):
stride = block_cfg.s if i == 0 else 1
if min(expand_ratio_list) >= 4:
expand_ratio_list = [_s for _s in expand_ratio_list if _s >= 4] if i == 0 else expand_ratio_list
mobile_inverted_conv = DynamicMBConvLayer(
in_channel_list=feature_dim,
out_channel_list=output_channel,
kernel_size_list=ks,
expand_ratio_list=expand_ratio_list,
stride=stride,
act_func=act_func,
use_se=use_se,
channels_per_group=getattr(self.supernet_cfg, 'channels_per_group', 1)
)
# Rule: add skip-connect, and use 2x2 AvgPool or 1x1 Conv for adaptation
shortcut = DynamicShortcutLayer(feature_dim, output_channel, reduction=stride)
blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
feature_dim = output_channel
self.blocks = nn.ModuleList(blocks)

last_channel, act_func = self.supernet_cfg.last_conv.c, self.supernet_cfg.last_conv.act_func
if not self.use_v3_head:
self.last_conv = DynamicConvLayer(
in_channel_list=feature_dim, out_channel_list=last_channel,
kernel_size=1, act_func=act_func,
)
else:
expand_feature_dim = [f_dim * 6 for f_dim in feature_dim]
self.last_conv = nn.Sequential(OrderedDict([
('final_expand_layer', DynamicConvLayer(
feature_dim, expand_feature_dim, kernel_size=1, use_bn=True, act_func=act_func)
),
('pool', nn.AdaptiveAvgPool2d((1,1))),
('feature_mix_layer', DynamicConvLayer(
in_channel_list=expand_feature_dim, out_channel_list=last_channel,
kernel_size=1, act_func=act_func, use_bn=False,)
),
]))

#final conv layer
self.classifier = DynamicLinearLayer(
in_features_list=last_channel, out_features=n_classes, bias=True
)

# set bn param
self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

# runtime_depth
self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

self.zero_residual_block_bn_weights()

self.active_dropout_rate = 0
self.active_drop_connect_rate = 0
self.active_resolution = 224

# Rule: Initialize learnable coefficient \gamma=0
def zero_residual_block_bn_weights(self):
with torch.no_grad():
for m in self.modules():
if isinstance(m, ResidualBlock):
if isinstance(m.mobile_inverted_conv, DynamicMBConvLayer) and m.shortcut is not None:
m.mobile_inverted_conv.point_linear.bn.bn.weight.zero_()

@staticmethod
def name():
return 'AttentiveNasModel'

def forward(self, x):
# resize input to target resolution first
if x.size(-1) != self.active_resolution:
x = F.interpolate(x, size=self.active_resolution, mode='bicubic')

# first conv
x = self.first_conv(x)
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
x = self.blocks[idx](x)

x = self.last_conv(x)
x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
x = torch.squeeze(x)

if self.active_dropout_rate > 0 and self.training:
x = F.dropout(x, p = self.active_dropout_rate)

x = self.classifier(x)
return x


@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
_str += self.blocks[0].module_str + '\n'

for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
_str += self.blocks[idx].module_str + '\n'
if not self.use_v3_head:
_str += self.last_conv.module_str + '\n'
else:
_str += self.last_conv.final_expand_layer.module_str + '\n'
_str += self.last_conv.feature_mix_layer.module_str + '\n'
_str += self.classifier.module_str + '\n'
return _str

@property
def config(self):
return {
'name': AttentiveNasDynamicModel.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
'last_conv': self.last_conv.config if not self.use_v3_head else None,
'final_expand_layer': self.last_conv.final_expand_layer if self.use_v3_head else None,
'feature_mix_layer': self.last_conv.feature_mix_layer if self.use_v3_head else None,
'classifier': self.classifier.config,
'resolution': self.active_resolution
}


@staticmethod
def build_from_config(config):
raise NotImplementedError

def get_parameters(self, keys=None, mode="include"):
if keys is None:
for name, param in self.named_parameters():
if param.requires_grad:
yield param
elif mode == "include":
for name, param in self.named_parameters():
flag = False
for key in keys:
if key in name:
flag = True
break
if flag and param.requires_grad:
yield param
elif mode == "exclude":
for name, param in self.named_parameters():
flag = True
for key in keys:
if key in name:
flag = False
break
if flag and param.requires_grad:
yield param
else:
raise ValueError("do not support: %s" % mode)

def set_bn_param(self, momentum, eps):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
if momentum is not None:
m.momentum = float(momentum)
else:
m.momentum = None
m.eps = float(eps)
return

def get_bn_param(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
return {
'momentum': m.momentum,
'eps': m.eps,
}
return None

""" set, sample and get active sub-networks """
def set_active_subnet(self, resolution=224, width=None, depth=None, kernel_size=None, expand_ratio=None, **kwargs):
assert len(depth) == len(kernel_size) == len(expand_ratio) == len(width) - 2
#set resolution
self.active_resolution = resolution

# first conv
self.first_conv.active_out_channel = width[0]

for stage_id, (c, k, e, d) in enumerate(zip(width[1:-1], kernel_size, expand_ratio, depth)):
start_idx, end_idx = min(self.block_group_info[stage_id]), max(self.block_group_info[stage_id])
for block_id in range(start_idx, start_idx+d):
block = self.blocks[block_id]
#block output channels
block.mobile_inverted_conv.active_out_channel = c
if block.shortcut is not None:
block.shortcut.active_out_channel = c

#dw kernel size
block.mobile_inverted_conv.active_kernel_size = k

# dw expansion ratio
block.mobile_inverted_conv.active_expand_ratio = e

# repeat times of IR blocks
for i, d in enumerate(depth):
self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

#last conv
if not self.use_v3_head:
self.last_conv.active_out_channel = width[-1]
else:
# default expansion ratio: 6
self.last_conv.final_expand_layer.active_out_channel = width[-2] * 6
self.last_conv.feature_mix_layer.active_out_channel = width[-1]

def get_active_subnet_settings(self):
r = self.active_resolution
width, depth, kernel_size, expand_ratio= [], [], [], []

#first conv
width.append(self.first_conv.active_out_channel)
for stage_id in range(len(self.block_group_info)):
start_idx = min(self.block_group_info[stage_id])
block = self.blocks[start_idx] #first block
width.append(block.mobile_inverted_conv.active_out_channel)
kernel_size.append(block.mobile_inverted_conv.active_kernel_size)
expand_ratio.append(block.mobile_inverted_conv.active_expand_ratio)
depth.append(self.runtime_depth[stage_id])
if not self.use_v3_head:
width.append(self.last_conv.active_out_channel)
else:
width.append(self.last_conv.feature_mix_layer.active_out_channel)

return {
'resolution': r,
'width': width,
'kernel_size': kernel_size,
'expand_ratio': expand_ratio,
'depth': depth,
}

def set_dropout_rate(self, dropout=0, drop_connect=0, drop_connect_only_last_two_stages=True):
self.active_dropout_rate = dropout
for idx, block in enumerate(self.blocks):
if drop_connect_only_last_two_stages:
if idx not in self.block_group_info[-1] + self.block_group_info[-2]:
continue
this_drop_connect_rate = drop_connect * float(idx) / len(self.blocks)
block.drop_connect_rate = this_drop_connect_rate


def sample_min_subnet(self):
return self._sample_active_subnet(min_net=True)


def sample_max_subnet(self):
return self._sample_active_subnet(max_net=True)

def sample_active_subnet(self, compute_flops=False):
cfg = self._sample_active_subnet(
False, False
)
if compute_flops:
cfg['flops'] = self.compute_active_subnet_flops()
return cfg

def sample_active_subnet_within_range(self, targeted_min_flops, targeted_max_flops):
while True:
cfg = self._sample_active_subnet()
cfg['flops'] = self.compute_active_subnet_flops()
if cfg['flops'] >= targeted_min_flops and cfg['flops'] <= targeted_max_flops:
return cfg

def _sample_active_subnet(self, min_net=False, max_net=False):

sample_cfg = lambda candidates, sample_min, sample_max: \
min(candidates) if sample_min else (max(candidates) if sample_max else random.choice(candidates))

cfg = {}
# sample a resolution
cfg['resolution'] = sample_cfg(self.cfg_candidates['resolution'], min_net, max_net)
for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
cfg[k] = []
for vv in self.cfg_candidates[k]:
cfg[k].append(sample_cfg(val2list(vv), min_net, max_net))

self.set_active_subnet(
cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio']
)
return cfg


def mutate_and_reset(self, cfg, prob=0.1, keep_resolution=False):
cfg = deepcopy(cfg)
pick_another = lambda x, candidates: x if len(candidates) == 1 else random.choice([v for v in candidates if v != x])
# sample a resolution
r = random.random()
if r < prob and not keep_resolution:
cfg['resolution'] = pick_another(cfg['resolution'], self.cfg_candidates['resolution'])

# sample channels, depth, kernel_size, expand_ratio
for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
for _i, _v in enumerate(cfg[k]):
r = random.random()
if r < prob:
cfg[k][_i] = pick_another(cfg[k][_i], val2list(self.cfg_candidates[k][_i]))

self.set_active_subnet(
cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio']
)
return cfg


def crossover_and_reset(self, cfg1, cfg2, p=0.5):
def _cross_helper(g1, g2, prob):
assert type(g1) == type(g2)
if isinstance(g1, int):
return g1 if random.random() < prob else g2
elif isinstance(g1, list):
return [v1 if random.random() < prob else v2 for v1, v2 in zip(g1, g2)]
else:
raise NotImplementedError

cfg = {}
cfg['resolution'] = cfg1['resolution'] if random.random() < p else cfg2['resolution']
for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
cfg[k] = _cross_helper(cfg1[k], cfg2[k], p)

self.set_active_subnet(
cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio']
)
return cfg


def get_active_subnet(self, preserve_weight=True):
with torch.no_grad():
first_conv = self.first_conv.get_active_subnet(3, preserve_weight)

blocks = []
input_channel = first_conv.out_channels
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append(ResidualBlock(
self.blocks[idx].mobile_inverted_conv.get_active_subnet(input_channel, preserve_weight),
self.blocks[idx].shortcut.get_active_subnet(input_channel, preserve_weight) if self.blocks[idx].shortcut is not None else None
))
input_channel = stage_blocks[-1].mobile_inverted_conv.out_channels
blocks += stage_blocks

if not self.use_v3_head:
last_conv = self.last_conv.get_active_subnet(input_channel, preserve_weight)
in_features = last_conv.out_channels
else:
final_expand_layer = self.last_conv.final_expand_layer.get_active_subnet(input_channel, preserve_weight)
feature_mix_layer = self.last_conv.feature_mix_layer.get_active_subnet(input_channel*6, preserve_weight)
in_features = feature_mix_layer.out_channels
last_conv = nn.Sequential(
final_expand_layer,
nn.AdaptiveAvgPool2d((1,1)),
feature_mix_layer
)

classifier = self.classifier.get_active_subnet(in_features, preserve_weight)

_subnet = AttentiveNasStaticModel(
first_conv, blocks, last_conv, classifier, self.active_resolution, use_v3_head=self.use_v3_head
)
_subnet.set_bn_param(**self.get_bn_param())
return _subnet


def compute_active_subnet_flops(self):

def count_conv(c_in, c_out, size_out, groups, k):
kernel_ops = k**2
output_elements = c_out * size_out**2
ops = c_in * output_elements * kernel_ops / groups
return ops

def count_linear(c_in, c_out):
return c_in * c_out

total_ops = 0

c_in = 3
size_out = self.active_resolution // self.first_conv.stride
c_out = self.first_conv.active_out_channel

total_ops += count_conv(c_in, c_out, size_out, 1, 3)
c_in = c_out

# mb blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
block = self.blocks[idx]
c_middle = make_divisible(round(c_in * block.mobile_inverted_conv.active_expand_ratio), 8)
# 1*1 conv
if block.mobile_inverted_conv.inverted_bottleneck is not None:
total_ops += count_conv(c_in, c_middle, size_out, 1, 1)
# dw conv
stride = 1 if idx > active_idx[0] else block.mobile_inverted_conv.stride
if size_out % stride == 0:
size_out = size_out // stride
else:
size_out = (size_out +1) // stride
total_ops += count_conv(c_middle, c_middle, size_out, c_middle, block.mobile_inverted_conv.active_kernel_size)
# 1*1 conv
c_out = block.mobile_inverted_conv.active_out_channel
total_ops += count_conv(c_middle, c_out, size_out, 1, 1)
#se
if block.mobile_inverted_conv.use_se:
num_mid = make_divisible(c_middle // block.mobile_inverted_conv.depth_conv.se.reduction, divisor=8)
total_ops += count_conv(c_middle, num_mid, 1, 1, 1) * 2
if block.shortcut and c_in != c_out:
total_ops += count_conv(c_in, c_out, size_out, 1, 1)
c_in = c_out

if not self.use_v3_head:
c_out = self.last_conv.active_out_channel
total_ops += count_conv(c_in, c_out, size_out, 1, 1)
else:
c_expand = self.last_conv.final_expand_layer.active_out_channel
c_out = self.last_conv.feature_mix_layer.active_out_channel
total_ops += count_conv(c_in, c_expand, size_out, 1, 1)
total_ops += count_conv(c_expand, c_out, 1, 1, 1)

# n_classes
total_ops += count_linear(c_out, self.n_classes)
return total_ops / 1e6


def load_weights_from_pretrained_models(self, checkpoint_path):
with open(checkpoint_path, 'rb') as f:
checkpoint = torch.load(f, map_location='cpu')
assert isinstance(checkpoint, dict)
pretrained_state_dicts = checkpoint['state_dict']
for k, v in self.state_dict().items():
name = 'module.' + k if not k.startswith('module') else k
v.copy_(pretrained_state_dicts[name])


def _AttentiveNAS_CNN():
from xnas.core.config import cfg
bn_momentum = cfg.ATTENTIVENAS.BN_MOMENTUM
bn_eps = cfg.ATTENTIVENAS.BN_EPS
return AttentiveNasDynamicModel(
cfg.ATTENTIVENAS.SUPERNET_CFG,
cfg.LOADER.NUM_CLASSES,
(bn_momentum, bn_eps),
)

def _infer_AttentiveNAS_CNN():
from xnas.core.config import cfg
bn_momentum = cfg.ATTENTIVENAS.BN_MOMENTUM
bn_eps = cfg.ATTENTIVENAS.BN_EPS
supernet = AttentiveNasDynamicModel(
cfg.ATTENTIVENAS.SUPERNET_CFG,
cfg.LOADER.NUM_CLASSES,
(bn_momentum, bn_eps),
)
# namespace changed: pareto_models.supernet_checkpoint_path
supernet.load_weights_from_pretrained_models(cfg.SEARCH.WEIGHT)
# namespace created: active_subnet.*
supernet.set_active_subnet(
resolution=cfg.ATTENTIVENAS.ACTIVE_SUBNET.RESOLUTION,
width = cfg.ATTENTIVENAS.ACTIVE_SUBNET.WIDTH,
depth = cfg.ATTENTIVENAS.ACTIVE_SUBNET.DEPTH,
kernel_size = cfg.ATTENTIVENAS.ACTIVE_SUBNET.KERNEL_SIZE,
expand_ratio = cfg.ATTENTIVENAS.ACTIVE_SUBNET.EXPAND_RATIO,
)
model = supernet.get_active_subnet()
# house-keeping: BN parameters here may use different values from the supernet
model.set_bn_param(momentum=bn_momentum, eps=bn_eps)
del supernet
return model
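A hedged sketch of the subnet-sampling API defined above, e.g. as primitives for evolutionary search; it assumes ATTENTIVENAS.SUPERNET_CFG and the related cfg fields are populated:

supernet = _AttentiveNAS_CNN()
cfg_rand = supernet.sample_active_subnet(compute_flops=True)   # cfg_rand['flops'] is in MFLOPs
child = supernet.mutate_and_reset(cfg_rand, prob=0.1)
child2 = supernet.crossover_and_reset(supernet.sample_min_subnet(), cfg_rand)
static_net = supernet.get_active_subnet(preserve_weight=True)  # standalone AttentiveNasStaticModel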

+ 653
- 0
xnas/spaces/BigNAS/cnn.py View File

@@ -0,0 +1,653 @@
# Implementation adapted from AttentiveNAS: https://github.com/facebookresearch/AttentiveNAS

import random
from copy import deepcopy
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from xnas.spaces.OFA.ops import ResidualBlock
from xnas.spaces.OFA.dynamic_ops import DynamicLinearLayer
from xnas.spaces.OFA.utils import val2list, make_divisible
from xnas.spaces.BigNAS.dynamic_layers import DynamicMBConvLayer, DynamicConvLayer, DynamicShortcutLayer


class BigNASStaticModel(nn.Module):

def __init__(self, first_conv, blocks, last_conv, classifier, resolution, use_v3_head=True):
super(BigNASStaticModel, self).__init__()
self.first_conv = first_conv
self.blocks = nn.ModuleList(blocks)
self.last_conv = last_conv
self.classifier = classifier

self.resolution = resolution #input size
self.use_v3_head = use_v3_head

def forward(self, x):
# resize input to target resolution first
# Rule: transform images into different sizes
if x.size(-1) != self.resolution:
x = F.interpolate(x, size=self.resolution, mode='bicubic')

x = self.first_conv(x)
for block in self.blocks:
x = block(x)
x = self.last_conv(x)
if not self.use_v3_head:
x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
x = torch.squeeze(x)
x = self.classifier(x)
return x


@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
for block in self.blocks:
_str += block.module_str + '\n'
#_str += self.last_conv.module_str + '\n'
_str += self.classifier.module_str
return _str

@property
def config(self):
return {
'name': BigNASStaticModel.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
#'last_conv': self.last_conv.config,
'classifier': self.classifier.config,
'resolution': self.resolution
}


def weight_initialization(self):
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)

@staticmethod
def build_from_config(config):
raise NotImplementedError

def set_bn_param(self, momentum, eps):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
if momentum is not None:
m.momentum = float(momentum)
else:
m.momentum = None
m.eps = float(eps)
return

def get_bn_param(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
return {
'momentum': m.momentum,
'eps': m.eps,
}
return None

def reset_running_stats_for_calibration(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
m.training = True
m.momentum = None # cumulative moving average
m.reset_running_stats()
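# Illustrative sketch (not part of this file): the post-training BN
# re-calibration loop that `reset_running_stats_for_calibration` above is
# designed for. Because subnets share BN layers with the supernet, running
# mean/var are re-estimated on a few batches before a subnet is evaluated.
# `calib_loader` and `num_calib_batches` are hypothetical names here.
def _example_recalibrate_bn(model, calib_loader, num_calib_batches=64, device="cuda"):
    model.to(device)
    model.eval()                                   # freeze everything else
    model.reset_running_stats_for_calibration()    # BN -> cumulative moving average
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(calib_loader):
            if batch_idx >= num_calib_batches:
                break
            model(images.to(device))               # forward pass updates running stats
    return model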


class BigNASDynamicModel(nn.Module):

def __init__(self, supernet_cfg, n_classes=1000, bn_param=(0., 1e-5)):
super(BigNASDynamicModel, self).__init__()

self.supernet_cfg = supernet_cfg
self.n_classes = n_classes
self.use_v3_head = getattr(self.supernet_cfg, 'use_v3_head', False)
self.stage_names = ['first_conv', 'mb1', 'mb2', 'mb3', 'mb4', 'mb5', 'mb6', 'mb7', 'last_conv']

self.width_list, self.depth_list, self.ks_list, self.expand_ratio_list = [], [], [], []
for name in self.stage_names:
block_cfg = getattr(self.supernet_cfg, name)
self.width_list.append(block_cfg.c)
if name.startswith('mb'):
self.depth_list.append(block_cfg.d)
self.ks_list.append(block_cfg.k)
self.expand_ratio_list.append(block_cfg.t)
self.resolution_list = self.supernet_cfg.resolutions

self.cfg_candidates = {
'resolution': self.resolution_list ,
'width': self.width_list,
'depth': self.depth_list,
'kernel_size': self.ks_list,
'expand_ratio': self.expand_ratio_list
}

#first conv layer, including conv, bn, act
out_channel_list, act_func, stride = \
self.supernet_cfg.first_conv.c, self.supernet_cfg.first_conv.act_func, self.supernet_cfg.first_conv.s
self.first_conv = DynamicConvLayer(
in_channel_list=val2list(3), out_channel_list=out_channel_list,
kernel_size=3, stride=stride, act_func=act_func,
)

# inverted residual blocks
self.block_group_info = []
blocks = []
_block_index = 0
feature_dim = out_channel_list
for stage_id, key in enumerate(self.stage_names[1:-1]):
block_cfg = getattr(self.supernet_cfg, key)
width = block_cfg.c
n_block = max(block_cfg.d)
act_func = block_cfg.act_func
ks = block_cfg.k
expand_ratio_list = block_cfg.t
use_se = block_cfg.se

self.block_group_info.append([_block_index + i for i in range(n_block)])
_block_index += n_block

output_channel = width
for i in range(n_block):
stride = block_cfg.s if i == 0 else 1
if min(expand_ratio_list) >= 4:
expand_ratio_list = [_s for _s in expand_ratio_list if _s >= 4] if i == 0 else expand_ratio_list
mobile_inverted_conv = DynamicMBConvLayer(
in_channel_list=feature_dim,
out_channel_list=output_channel,
kernel_size_list=ks,
expand_ratio_list=expand_ratio_list,
stride=stride,
act_func=act_func,
use_se=use_se,
channels_per_group=getattr(self.supernet_cfg, 'channels_per_group', 1)
)
# Rule: add skip-connect, and use 2x2 AvgPool or 1x1 Conv for adaptation
shortcut = DynamicShortcutLayer(feature_dim, output_channel, reduction=stride)
blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
feature_dim = output_channel
self.blocks = nn.ModuleList(blocks)

last_channel, act_func = self.supernet_cfg.last_conv.c, self.supernet_cfg.last_conv.act_func
if not self.use_v3_head:
self.last_conv = DynamicConvLayer(
in_channel_list=feature_dim, out_channel_list=last_channel,
kernel_size=1, act_func=act_func,
)
else:
expand_feature_dim = [f_dim * 6 for f_dim in feature_dim]
self.last_conv = nn.Sequential(OrderedDict([
('final_expand_layer', DynamicConvLayer(
feature_dim, expand_feature_dim, kernel_size=1, use_bn=True, act_func=act_func)
),
('pool', nn.AdaptiveAvgPool2d((1,1))),
('feature_mix_layer', DynamicConvLayer(
in_channel_list=expand_feature_dim, out_channel_list=last_channel,
kernel_size=1, act_func=act_func, use_bn=False,)
),
]))

# classifier
self.classifier = DynamicLinearLayer(
in_features_list=last_channel, out_features=n_classes, bias=True
)

# set bn param
self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

# runtime_depth
self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

self.zero_residual_block_bn_weights()

self.active_dropout_rate = 0
self.active_drop_connect_rate = 0
self.active_resolution = 224

# Rule: Initialize learnable coefficient \gamma=0
def zero_residual_block_bn_weights(self):
with torch.no_grad():
for m in self.modules():
if isinstance(m, ResidualBlock):
if isinstance(m.mobile_inverted_conv, DynamicMBConvLayer) and m.shortcut is not None:
m.mobile_inverted_conv.point_linear.bn.bn.weight.zero_()

@staticmethod
def name():
return 'BigNASDynamicModel'

def forward(self, x):
# resize input to target resolution first
if x.size(-1) != self.active_resolution:
x = F.interpolate(x, size=self.active_resolution, mode='bicubic')

# first conv
x = self.first_conv(x)
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
x = self.blocks[idx](x)

x = self.last_conv(x)
x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
x = torch.squeeze(x)

if self.active_dropout_rate > 0 and self.training:
x = F.dropout(x, p = self.active_dropout_rate)

x = self.classifier(x)
return x


@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
_str += self.blocks[0].module_str + '\n'

for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
_str += self.blocks[idx].module_str + '\n'
if not self.use_v3_head:
_str += self.last_conv.module_str + '\n'
else:
_str += self.last_conv.final_expand_layer.module_str + '\n'
_str += self.last_conv.feature_mix_layer.module_str + '\n'
_str += self.classifier.module_str + '\n'
return _str

@property
def config(self):
return {
'name': BigNASDynamicModel.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
'last_conv': self.last_conv.config if not self.use_v3_head else None,
'final_expand_layer': self.last_conv.final_expand_layer.config if self.use_v3_head else None,
'feature_mix_layer': self.last_conv.feature_mix_layer.config if self.use_v3_head else None,
'classifier': self.classifier.config,
'resolution': self.active_resolution
}


@staticmethod
def build_from_config(config):
raise NotImplementedError

def get_parameters(self, keys=None, mode="include"):
if keys is None:
for name, param in self.named_parameters():
if param.requires_grad:
yield param
elif mode == "include":
for name, param in self.named_parameters():
flag = False
for key in keys:
if key in name:
flag = True
break
if flag and param.requires_grad:
yield param
elif mode == "exclude":
for name, param in self.named_parameters():
flag = True
for key in keys:
if key in name:
flag = False
break
if flag and param.requires_grad:
yield param
else:
raise ValueError("do not support: %s" % mode)

def set_bn_param(self, momentum, eps):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
if momentum is not None:
m.momentum = float(momentum)
else:
m.momentum = None
m.eps = float(eps)
return

def get_bn_param(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
return {
'momentum': m.momentum,
'eps': m.eps,
}
return None

""" set, sample and get active sub-networks """
def set_active_subnet(self, resolution=224, width=None, depth=None, kernel_size=None, expand_ratio=None, **kwargs):
assert len(depth) == len(kernel_size) == len(expand_ratio) == len(width) - 2
#set resolution
self.active_resolution = resolution

# first conv
self.first_conv.active_out_channel = width[0]

for stage_id, (c, k, e, d) in enumerate(zip(width[1:-1], kernel_size, expand_ratio, depth)):
start_idx, end_idx = min(self.block_group_info[stage_id]), max(self.block_group_info[stage_id])
for block_id in range(start_idx, start_idx+d):
block = self.blocks[block_id]
#block output channels
block.mobile_inverted_conv.active_out_channel = c
if block.shortcut is not None:
block.shortcut.active_out_channel = c

#dw kernel size
block.mobile_inverted_conv.active_kernel_size = k

#dw expansion ratio
block.mobile_inverted_conv.active_expand_ratio = e

#number of times IR blocks are repeated per stage
for i, d in enumerate(depth):
self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

#last conv
if not self.use_v3_head:
self.last_conv.active_out_channel = width[-1]
else:
# default expansion ratio: 6
self.last_conv.final_expand_layer.active_out_channel = width[-2] * 6
self.last_conv.feature_mix_layer.active_out_channel = width[-1]

def get_active_subnet_settings(self):
r = self.active_resolution
width, depth, kernel_size, expand_ratio= [], [], [], []

#first conv
width.append(self.first_conv.active_out_channel)
for stage_id in range(len(self.block_group_info)):
start_idx = min(self.block_group_info[stage_id])
block = self.blocks[start_idx] #first block
width.append(block.mobile_inverted_conv.active_out_channel)
kernel_size.append(block.mobile_inverted_conv.active_kernel_size)
expand_ratio.append(block.mobile_inverted_conv.active_expand_ratio)
depth.append(self.runtime_depth[stage_id])
if not self.use_v3_head:
width.append(self.last_conv.active_out_channel)
else:
width.append(self.last_conv.feature_mix_layer.active_out_channel)

return {
'resolution': r,
'width': width,
'kernel_size': kernel_size,
'expand_ratio': expand_ratio,
'depth': depth,
}

def set_dropout_rate(self, dropout=0, drop_connect=0, drop_connect_only_last_two_stages=True):
self.active_dropout_rate = dropout
for idx, block in enumerate(self.blocks):
if drop_connect_only_last_two_stages:
if idx not in self.block_group_info[-1] + self.block_group_info[-2]:
continue
this_drop_connect_rate = drop_connect * float(idx) / len(self.blocks)
block.drop_connect_rate = this_drop_connect_rate


def sample_min_subnet(self):
return self._sample_active_subnet(min_net=True)


def sample_max_subnet(self):
return self._sample_active_subnet(max_net=True)

def sample_active_subnet(self, compute_flops=False):
cfg = self._sample_active_subnet(
False, False
)
if compute_flops:
cfg['flops'] = self.compute_active_subnet_flops()
return cfg

def sample_active_subnet_within_range(self, targeted_min_flops, targeted_max_flops):
while True:
cfg = self._sample_active_subnet()
cfg['flops'] = self.compute_active_subnet_flops()
if cfg['flops'] >= targeted_min_flops and cfg['flops'] <= targeted_max_flops:
return cfg

def _sample_active_subnet(self, min_net=False, max_net=False):

sample_cfg = lambda candidates, sample_min, sample_max: \
min(candidates) if sample_min else (max(candidates) if sample_max else random.choice(candidates))

cfg = {}
# sample a resolution
cfg['resolution'] = sample_cfg(self.cfg_candidates['resolution'], min_net, max_net)
for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
cfg[k] = []
for vv in self.cfg_candidates[k]:
cfg[k].append(sample_cfg(val2list(vv), min_net, max_net))

self.set_active_subnet(
cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio']
)
return cfg


def mutate_and_reset(self, cfg, prob=0.1, keep_resolution=False):
cfg = deepcopy(cfg)
pick_another = lambda x, candidates: x if len(candidates) == 1 else random.choice([v for v in candidates if v != x])
# sample a resolution
r = random.random()
if r < prob and not keep_resolution:
cfg['resolution'] = pick_another(cfg['resolution'], self.cfg_candidates['resolution'])

# sample channels, depth, kernel_size, expand_ratio
for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
for _i, _v in enumerate(cfg[k]):
r = random.random()
if r < prob:
cfg[k][_i] = pick_another(cfg[k][_i], val2list(self.cfg_candidates[k][_i]))

self.set_active_subnet(
cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio']
)
return cfg


def crossover_and_reset(self, cfg1, cfg2, p=0.5):
def _cross_helper(g1, g2, prob):
assert type(g1) == type(g2)
if isinstance(g1, int):
return g1 if random.random() < prob else g2
elif isinstance(g1, list):
return [v1 if random.random() < prob else v2 for v1, v2 in zip(g1, g2)]
else:
raise NotImplementedError

cfg = {}
cfg['resolution'] = cfg1['resolution'] if random.random() < p else cfg2['resolution']
for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
cfg[k] = _cross_helper(cfg1[k], cfg2[k], p)

self.set_active_subnet(
cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio']
)
return cfg


def get_active_subnet(self, preserve_weight=True):
with torch.no_grad():
first_conv = self.first_conv.get_active_subnet(3, preserve_weight)

blocks = []
input_channel = first_conv.out_channels
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append(ResidualBlock(
self.blocks[idx].mobile_inverted_conv.get_active_subnet(input_channel, preserve_weight),
self.blocks[idx].shortcut.get_active_subnet(input_channel, preserve_weight) if self.blocks[idx].shortcut is not None else None
))
input_channel = stage_blocks[-1].mobile_inverted_conv.out_channels
blocks += stage_blocks

if not self.use_v3_head:
last_conv = self.last_conv.get_active_subnet(input_channel, preserve_weight)
in_features = last_conv.out_channels
else:
final_expand_layer = self.last_conv.final_expand_layer.get_active_subnet(input_channel, preserve_weight)
feature_mix_layer = self.last_conv.feature_mix_layer.get_active_subnet(input_channel*6, preserve_weight)
in_features = feature_mix_layer.out_channels
last_conv = nn.Sequential(
final_expand_layer,
nn.AdaptiveAvgPool2d((1,1)),
feature_mix_layer
)

classifier = self.classifier.get_active_subnet(in_features, preserve_weight)

_subnet = BigNASStaticModel(
first_conv, blocks, last_conv, classifier, self.active_resolution, use_v3_head=self.use_v3_head
)
_subnet.set_bn_param(**self.get_bn_param())
return _subnet


def compute_active_subnet_flops(self):

def count_conv(c_in, c_out, size_out, groups, k):
kernel_ops = k**2
output_elements = c_out * size_out**2
ops = c_in * output_elements * kernel_ops / groups
return ops

def count_linear(c_in, c_out):
return c_in * c_out

total_ops = 0

c_in = 3
size_out = self.active_resolution // self.first_conv.stride
c_out = self.first_conv.active_out_channel

total_ops += count_conv(c_in, c_out, size_out, 1, 3)
c_in = c_out

# mb blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
block = self.blocks[idx]
c_middle = make_divisible(round(c_in * block.mobile_inverted_conv.active_expand_ratio), 8)
# 1*1 conv
if block.mobile_inverted_conv.inverted_bottleneck is not None:
total_ops += count_conv(c_in, c_middle, size_out, 1, 1)
# dw conv
stride = 1 if idx > active_idx[0] else block.mobile_inverted_conv.stride
if size_out % stride == 0:
size_out = size_out // stride
else:
size_out = (size_out +1) // stride
total_ops += count_conv(c_middle, c_middle, size_out, c_middle, block.mobile_inverted_conv.active_kernel_size)
# 1*1 conv
c_out = block.mobile_inverted_conv.active_out_channel
total_ops += count_conv(c_middle, c_out, size_out, 1, 1)
#se
if block.mobile_inverted_conv.use_se:
num_mid = make_divisible(c_middle // block.mobile_inverted_conv.depth_conv.se.reduction, divisor=8)
total_ops += count_conv(c_middle, num_mid, 1, 1, 1) * 2
if block.shortcut and c_in != c_out:
total_ops += count_conv(c_in, c_out, size_out, 1, 1)
c_in = c_out

if not self.use_v3_head:
c_out = self.last_conv.active_out_channel
total_ops += count_conv(c_in, c_out, size_out, 1, 1)
else:
c_expand = self.last_conv.final_expand_layer.active_out_channel
c_out = self.last_conv.feature_mix_layer.active_out_channel
total_ops += count_conv(c_in, c_expand, size_out, 1, 1)
total_ops += count_conv(c_expand, c_out, 1, 1, 1)

# n_classes
total_ops += count_linear(c_out, self.n_classes)
return total_ops / 1e6


def load_weights_from_pretrained_models(self, checkpoint_path):
with open(checkpoint_path, 'rb') as f:
checkpoint = torch.load(f, map_location='cpu')
assert isinstance(checkpoint, dict)
pretrained_state_dicts = checkpoint['model_state']
for k, v in self.state_dict().items():
# name = 'module.' + k if not k.startswith('module') else k
name = k
v.copy_(pretrained_state_dicts[name])
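# Illustrative sketch (not part of this file): a single sandwich-rule
# training step built on the samplers above (sample_max_subnet,
# sample_min_subnet, sample_active_subnet). The actual training recipe in
# this PR is more involved; this is only a minimal outline. `optimizer`,
# `criterion`, and `num_random_subnets` are hypothetical names.
def _example_sandwich_step(supernet, images, labels, optimizer, criterion, num_random_subnets=2):
    optimizer.zero_grad()
    # largest subnet first (commonly also used as a distillation teacher)
    supernet.sample_max_subnet()
    criterion(supernet(images), labels).backward()
    # smallest subnet
    supernet.sample_min_subnet()
    criterion(supernet(images), labels).backward()
    # a few randomly sampled subnets
    for _ in range(num_random_subnets):
        supernet.sample_active_subnet()
        criterion(supernet(images), labels).backward()
    # gradients from all sampled subnets are accumulated into the shared weights
    optimizer.step()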


def _BigNAS_CNN():
from xnas.core.config import cfg
bn_momentum = cfg.BIGNAS.BN_MOMENTUM
bn_eps = cfg.BIGNAS.BN_EPS
return BigNASDynamicModel(
cfg.BIGNAS.SUPERNET_CFG,
cfg.LOADER.NUM_CLASSES,
(bn_momentum, bn_eps),
)

def _infer_BigNAS_CNN():
from xnas.core.config import cfg
bn_momentum = cfg.BIGNAS.BN_MOMENTUM
bn_eps = cfg.BIGNAS.BN_EPS
supernet = BigNASDynamicModel(
cfg.BIGNAS.SUPERNET_CFG,
cfg.LOADER.NUM_CLASSES,
(bn_momentum, bn_eps),
)
# namespace changed: pareto_models.supernet_checkpoint_path
supernet.load_weights_from_pretrained_models(cfg.SEARCH.WEIGHT)
# namespace created: active_subnet.*
supernet.set_active_subnet(
resolution=cfg.BIGNAS.ACTIVE_SUBNET.RESOLUTION,
width = cfg.BIGNAS.ACTIVE_SUBNET.WIDTH,
depth = cfg.BIGNAS.ACTIVE_SUBNET.DEPTH,
kernel_size = cfg.BIGNAS.ACTIVE_SUBNET.KERNEL_SIZE,
expand_ratio = cfg.BIGNAS.ACTIVE_SUBNET.EXPAND_RATIO,
)
model = supernet.get_active_subnet()
# house-keeping: BN params here may use values different from the supernet's
model.set_bn_param(momentum=bn_momentum, eps=bn_eps)
del supernet
return model
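# Illustrative sketch (not part of this file): a minimal FLOPs-constrained
# evolutionary loop over subnet configs, built on
# sample_active_subnet_within_range, mutate_and_reset, crossover_and_reset,
# and compute_active_subnet_flops defined above. `evaluate_subnet(supernet, cfg)`
# is a hypothetical placeholder expected to re-activate `cfg` and return a
# score (e.g., calibrated validation accuracy); the hyper-parameters are
# placeholders as well.
def _example_evolutionary_search(supernet, evaluate_subnet,
                                 min_flops=200, max_flops=400,
                                 population_size=8, iterations=10):
    population = [supernet.sample_active_subnet_within_range(min_flops, max_flops)
                  for _ in range(population_size)]
    scored = [(evaluate_subnet(supernet, c), c) for c in population]
    for _ in range(iterations):
        scored.sort(key=lambda t: t[0], reverse=True)
        parents = [c for _, c in scored[:max(2, population_size // 2)]]
        children = []
        while len(children) < population_size:
            p1, p2 = random.sample(parents, 2)
            child = supernet.crossover_and_reset(p1, p2)
            child = supernet.mutate_and_reset(child, prob=0.2)
            # mutate_and_reset leaves `child` active, so the FLOPs check below
            # refers to the candidate that was just generated
            if min_flops <= supernet.compute_active_subnet_flops() <= max_flops:
                children.append(child)
        scored = [(evaluate_subnet(supernet, c), c) for c in children]
    return max(scored, key=lambda t: t[0])  # (best_score, best_cfg)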

+ 331
- 0
xnas/spaces/BigNAS/dynamic_layers.py View File

@@ -0,0 +1,331 @@
from collections import OrderedDict

import torch.nn as nn
import torch.nn.functional as F


from xnas.spaces.OFA.utils import val2list
from xnas.spaces.OFA.ops import SEModule, ConvLayer, ShortcutLayer, build_activation, make_divisible
from xnas.spaces.OFA.dynamic_ops import DynamicConv2d, DynamicSE, copy_bn
from xnas.spaces.BigNAS.ops import MBConvLayer
from xnas.spaces.BigNAS.dynamic_ops import DynamicSeparableConv2d, DynamicBatchNorm2d


class DynamicMBConvLayer(nn.Module):
def __init__(self, in_channel_list, out_channel_list,
kernel_size_list=3, expand_ratio_list=6, stride=1, act_func='relu6', use_se=False, channels_per_group=1):
super(DynamicMBConvLayer, self).__init__()
self.in_channel_list = val2list(in_channel_list)
self.out_channel_list = val2list(out_channel_list)
self.kernel_size_list = val2list(kernel_size_list, 1)
self.expand_ratio_list = val2list(expand_ratio_list, 1)
self.stride = stride
self.act_func = act_func
self.use_se = use_se
self.channels_per_group = channels_per_group
# build modules
max_middle_channel = round(max(self.in_channel_list) * max(self.expand_ratio_list))
if max(self.expand_ratio_list) == 1:
self.inverted_bottleneck = None
else:
self.inverted_bottleneck = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max(self.in_channel_list), max_middle_channel)),
('bn', DynamicBatchNorm2d(max_middle_channel)),
('act', build_activation(self.act_func, inplace=True)),
]))
self.depth_conv = nn.Sequential(OrderedDict([
('conv', DynamicSeparableConv2d(max_middle_channel, self.kernel_size_list, stride=self.stride, channels_per_group=self.channels_per_group)),
('bn', DynamicBatchNorm2d(max_middle_channel)),
('act', build_activation(self.act_func, inplace=True))
]))
if self.use_se:
self.depth_conv.add_module('se', DynamicSE(max_middle_channel))
self.point_linear = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max_middle_channel, max(self.out_channel_list))),
('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
]))
self.active_kernel_size = max(self.kernel_size_list)
self.active_expand_ratio = max(self.expand_ratio_list)
self.active_out_channel = max(self.out_channel_list)
def forward(self, x):
in_channel = x.size(1)
if self.inverted_bottleneck is not None:
self.inverted_bottleneck.conv.active_out_channel = \
make_divisible(round(in_channel * self.active_expand_ratio), 8)

self.depth_conv.conv.active_kernel_size = self.active_kernel_size
self.point_linear.conv.active_out_channel = self.active_out_channel
if self.inverted_bottleneck is not None:
x = self.inverted_bottleneck(x)
x = self.depth_conv(x)
x = self.point_linear(x)
return x
@property
def module_str(self):
if self.use_se:
return 'SE(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)
else:
return '(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)
@property
def config(self):
return {
'name': DynamicMBConvLayer.__name__,
'in_channel_list': self.in_channel_list,
'out_channel_list': self.out_channel_list,
'kernel_size_list': self.kernel_size_list,
'expand_ratio_list': self.expand_ratio_list,
'stride': self.stride,
'act_func': self.act_func,
'use_se': self.use_se,
'channels_per_group': self.channels_per_group,
}
@staticmethod
def build_from_config(config):
return DynamicMBConvLayer(**config)

############################################################################################

def get_active_subnet(self, in_channel, preserve_weight=True):
middle_channel = make_divisible(round(in_channel * self.active_expand_ratio), 8)
channels_per_group = self.depth_conv.conv.channels_per_group

# build the new layer
sub_layer = MBConvLayer(
in_channel, self.active_out_channel, self.active_kernel_size, self.stride, self.active_expand_ratio,
act_func=self.act_func, mid_channels=middle_channel, use_se=self.use_se, channels_per_group=channels_per_group
)
sub_layer = sub_layer.to(self.parameters().__next__().device)

if not preserve_weight:
return sub_layer

# copy weight from current layer
if sub_layer.inverted_bottleneck is not None:
sub_layer.inverted_bottleneck.conv.weight.data.copy_(
self.inverted_bottleneck.conv.conv.weight.data[:middle_channel, :in_channel, :, :]
)
copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn)

sub_layer.depth_conv.conv.weight.data.copy_(
self.depth_conv.conv.get_active_filter(middle_channel, self.active_kernel_size).data
)
copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn)

if self.use_se:
se_mid = make_divisible(middle_channel // SEModule.REDUCTION, divisor=8)
sub_layer.depth_conv.se.fc.reduce.weight.data.copy_(
self.depth_conv.se.fc.reduce.weight.data[:se_mid, :middle_channel, :, :]
)
sub_layer.depth_conv.se.fc.reduce.bias.data.copy_(self.depth_conv.se.fc.reduce.bias.data[:se_mid])

sub_layer.depth_conv.se.fc.expand.weight.data.copy_(
self.depth_conv.se.fc.expand.weight.data[:middle_channel, :se_mid, :, :]
)
sub_layer.depth_conv.se.fc.expand.bias.data.copy_(self.depth_conv.se.fc.expand.bias.data[:middle_channel])

sub_layer.point_linear.conv.weight.data.copy_(
self.point_linear.conv.conv.weight.data[:self.active_out_channel, :middle_channel, :, :]
)
copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn)

return sub_layer

def re_organize_middle_weights(self, expand_ratio_stage=0):
# importance = torch.sum(torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3))
# if expand_ratio_stage > 0:
# sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
# sorted_expand_list.sort(reverse=True)
# target_width = sorted_expand_list[expand_ratio_stage]
# target_width = round(max(self.in_channel_list) * target_width)
# importance[target_width:] = torch.arange(0, target_width - importance.size(0), -1)
# sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
# self.point_linear.conv.conv.weight.data = torch.index_select(
# self.point_linear.conv.conv.weight.data, 1, sorted_idx
# )
# adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx)
# self.depth_conv.conv.conv.weight.data = torch.index_select(
# self.depth_conv.conv.conv.weight.data, 0, sorted_idx
# )

# if self.use_se:
# # se expand: output dim 0 reorganize
# se_expand = self.depth_conv.se.fc.expand
# se_expand.weight.data = torch.index_select(se_expand.weight.data, 0, sorted_idx)
# se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx)
# # se reduce: input dim 1 reorganize
# se_reduce = self.depth_conv.se.fc.reduce
# se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 1, sorted_idx)
# # middle weight reorganize
# se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3))
# se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True)

# se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx)
# se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx)
# se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx)
# # TODO if inverted_bottleneck is None, the previous layer should be reorganized accordingly
# if self.inverted_bottleneck is not None:
# adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx)
# self.inverted_bottleneck.conv.conv.weight.data = torch.index_select(
# self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx
# )
# return None
# else:
# return sorted_idx
raise NotImplementedError


class DynamicConvLayer(nn.Module):
def __init__(self, in_channel_list, out_channel_list, kernel_size=3, stride=1, dilation=1,
use_bn=True, act_func='relu6'):
super(DynamicConvLayer, self).__init__()
self.in_channel_list = val2list(in_channel_list)
self.out_channel_list = val2list(out_channel_list)
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.use_bn = use_bn
self.act_func = act_func
self.conv = DynamicConv2d(
max_in_channels=max(self.in_channel_list), max_out_channels=max(self.out_channel_list),
kernel_size=self.kernel_size, stride=self.stride, dilation=self.dilation,
)
if self.use_bn:
self.bn = DynamicBatchNorm2d(max(self.out_channel_list))

if self.act_func is not None:
self.act = build_activation(self.act_func, inplace=True)
self.active_out_channel = max(self.out_channel_list)
def forward(self, x):
self.conv.active_out_channel = self.active_out_channel
x = self.conv(x)
if self.use_bn:
x = self.bn(x)
if self.act_func is not None:
x = self.act(x)
return x
@property
def module_str(self):
return 'DyConv(O%d, K%d, S%d)' % (self.active_out_channel, self.kernel_size, self.stride)
@property
def config(self):
return {
'name': DynamicConvLayer.__name__,
'in_channel_list': self.in_channel_list,
'out_channel_list': self.out_channel_list,
'kernel_size': self.kernel_size,
'stride': self.stride,
'dilation': self.dilation,
'use_bn': self.use_bn,
'act_func': self.act_func,
}
@staticmethod
def build_from_config(config):
return DynamicConvLayer(**config)
def get_active_subnet(self, in_channel, preserve_weight=True):
sub_layer = ConvLayer(
in_channel, self.active_out_channel, self.kernel_size, self.stride, self.dilation,
use_bn=self.use_bn, act_func=self.act_func
)
sub_layer = sub_layer.to(self.parameters().__next__().device)
if not preserve_weight:
return sub_layer
sub_layer.conv.weight.data.copy_(self.conv.conv.weight.data[:self.active_out_channel, :in_channel, :, :])
if self.use_bn:
copy_bn(sub_layer.bn, self.bn.bn)
return sub_layer


class DynamicShortcutLayer(nn.Module):
def __init__(self, in_channel_list, out_channel_list, reduction=1):
super(DynamicShortcutLayer, self).__init__()
self.in_channel_list = val2list(in_channel_list)
self.out_channel_list = val2list(out_channel_list)
self.reduction = reduction
self.conv = DynamicConv2d(
max_in_channels=max(self.in_channel_list), max_out_channels=max(self.out_channel_list),
kernel_size=1, stride=1,
)

self.active_out_channel = max(self.out_channel_list)
def forward(self, x):
in_channel = x.size(1)

#identity mapping
if in_channel == self.active_out_channel and self.reduction == 1:
return x
#average pooling, if size doesn't match
if self.reduction > 1:
padding = 0 if x.size(-1) % 2 == 0 else 1
x = F.avg_pool2d(x, self.reduction, padding=padding)

#1*1 conv, if #channels doesn't match
if in_channel != self.active_out_channel:
self.conv.active_out_channel = self.active_out_channel
x = self.conv(x)
return x
@property
def module_str(self):
return 'DyShortcut(O%d, R%d)' % (self.active_out_channel, self.reduction)
@property
def config(self):
return {
'name': DynamicShortcutLayer.__name__,
'in_channel_list': self.in_channel_list,
'out_channel_list': self.out_channel_list,
'reduction': self.reduction,
}
@staticmethod
def build_from_config(config):
return DynamicShortcutLayer(**config)
def get_active_subnet(self, in_channel, preserve_weight=True):
sub_layer = ShortcutLayer(
in_channel, self.active_out_channel, self.reduction
)
sub_layer = sub_layer.to(self.parameters().__next__().device)
if not preserve_weight:
return sub_layer
sub_layer.conv.weight.data.copy_(self.conv.conv.weight.data[:self.active_out_channel, :in_channel, :, :])
return sub_layer


+ 181
- 0
xnas/spaces/BigNAS/dynamic_ops.py View File

@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.autograd.function import Function

from xnas.spaces.OFA.ops import get_same_padding
from xnas.spaces.OFA.dynamic_ops import sub_filter_start_end



class DynamicSeparableConv2d(nn.Module):
# KERNEL_TRANSFORM_MODE = None # None or 1
def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1, channels_per_group=1):
super(DynamicSeparableConv2d, self).__init__()
self.max_in_channels = max_in_channels
self.channels_per_group = channels_per_group
assert self.max_in_channels % self.channels_per_group == 0
self.kernel_size_list = kernel_size_list
self.stride = stride
self.dilation = dilation
self.conv = nn.Conv2d(
self.max_in_channels, self.max_in_channels, max(self.kernel_size_list), self.stride,
groups=self.max_in_channels // self.channels_per_group, bias=False,
)
self._ks_set = list(set(self.kernel_size_list))
self._ks_set.sort() # e.g., [3, 5, 7]
# if self.KERNEL_TRANSFORM_MODE is not None:
# # register scaling parameters
# # 7to5_matrix, 5to3_matrix
# scale_params = {}
# for i in range(len(self._ks_set) - 1):
# ks_small = self._ks_set[i]
# ks_larger = self._ks_set[i + 1]
# param_name = '%dto%d' % (ks_larger, ks_small)
# scale_params['%s_matrix' % param_name] = Parameter(torch.eye(ks_small ** 2))
# for name, param in scale_params.items():
# self.register_parameter(name, param)

self.active_kernel_size = max(self.kernel_size_list)
def get_active_filter(self, in_channel, kernel_size):
out_channel = in_channel
max_kernel_size = max(self.kernel_size_list)
start, end = sub_filter_start_end(max_kernel_size, kernel_size)
filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
# if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
# start_filter = self.conv.weight[:out_channel, :in_channel, :, :] # start with max kernel
# for i in range(len(self._ks_set) - 1, 0, -1):
# src_ks = self._ks_set[i]
# if src_ks <= kernel_size:
# break
# target_ks = self._ks_set[i - 1]
# start, end = sub_filter_start_end(src_ks, target_ks)
# _input_filter = start_filter[:, :, start:end, start:end]
# _input_filter = _input_filter.contiguous()
# _input_filter = _input_filter.view(_input_filter.size(0), _input_filter.size(1), -1)
# _input_filter = _input_filter.view(-1, _input_filter.size(2))
# _input_filter = F.linear(
# _input_filter, self.__getattr__('%dto%d_matrix' % (src_ks, target_ks)),
# )
# _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks ** 2)
# _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks, target_ks)
# start_filter = _input_filter
# filters = start_filter
return filters
def forward(self, x, kernel_size=None):
if kernel_size is None:
kernel_size = self.active_kernel_size
in_channel = x.size(1)
assert in_channel % self.channels_per_group == 0
filters = self.get_active_filter(in_channel, kernel_size).contiguous()
padding = get_same_padding(kernel_size)
y = F.conv2d(
x, filters, None, self.stride, padding, self.dilation, in_channel // self.channels_per_group
)
return y


class AllReduce(Function):
@staticmethod
def forward(ctx, input):
input_list = [torch.zeros_like(input) for k in range(dist.get_world_size())]
# Use allgather instead of allreduce since I don't trust in-place operations ..
dist.all_gather(input_list, input, async_op=False)
inputs = torch.stack(input_list, dim=0)
return torch.sum(inputs, dim=0)

@staticmethod
def backward(ctx, grad_output):
dist.all_reduce(grad_output, async_op=False)
return grad_output


class DynamicBatchNorm2d(nn.Module):
'''
1. doesn't accumulate bn statistics (momentum=0.)
2. calculate BN statistics of all subnets after training
3. bn weights are shared
https://arxiv.org/abs/1903.05134
https://detectron2.readthedocs.io/_modules/detectron2/layers/batch_norm.html
'''
#SET_RUNNING_STATISTICS = False
def __init__(self, max_feature_dim):
super(DynamicBatchNorm2d, self).__init__()
self.max_feature_dim = max_feature_dim
self.bn = nn.BatchNorm2d(self.max_feature_dim)

# self.exponential_average_factor = 0 # doesn't accumulate bn stats
self.need_sync = False # sync batch normalization, suggested for BigNAS

# reserved for tracking the performance of the largest and smallest networks
self.bn_tracking = nn.ModuleList(
[
nn.BatchNorm2d(self.max_feature_dim, affine=False),
nn.BatchNorm2d(self.max_feature_dim, affine=False)
]
)

def forward(self, x):
feature_dim = x.size(1)
if not self.training:
raise ValueError('DynamicBN only supports training')
bn = self.bn
# need_sync
if not self.need_sync:
return F.batch_norm(
x, bn.running_mean[:feature_dim], bn.running_var[:feature_dim], bn.weight[:feature_dim],
bn.bias[:feature_dim], bn.training or not bn.track_running_stats,
bn.momentum, bn.eps,
)
else:
assert dist.get_world_size() > 1, 'SyncBatchNorm requires >1 world size'
B, C = x.shape[0], x.shape[1]
mean = torch.mean(x, dim=[0, 2, 3])
meansqr = torch.mean(x * x, dim=[0, 2, 3])
assert B > 0, 'does not support zero batch size'
vec = torch.cat([mean, meansqr], dim=0)
vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
mean, meansqr = torch.split(vec, C)

var = meansqr - mean * mean
invstd = torch.rsqrt(var + bn.eps)
scale = bn.weight[:feature_dim] * invstd
bias = bn.bias[:feature_dim] - mean * scale
scale = scale.reshape(1, -1, 1, 1)
bias = bias.reshape(1, -1, 1, 1)
return x * scale + bias


#if bn.num_features == feature_dim or DynamicBatchNorm2d.SET_RUNNING_STATISTICS:
# return bn(x)
#else:
# exponential_average_factor = 0.0

# if bn.training and bn.track_running_stats:
# # TODO: if statement only here to tell the jit to skip emitting this when it is None
# if bn.num_batches_tracked is not None:
# bn.num_batches_tracked += 1
# if bn.momentum is None: # use cumulative moving average
# exponential_average_factor = 1.0 / float(bn.num_batches_tracked)
# else: # use exponential moving average
# exponential_average_factor = bn.momentum
# return F.batch_norm(
# x, bn.running_mean[:feature_dim], bn.running_var[:feature_dim], bn.weight[:feature_dim],
# bn.bias[:feature_dim], bn.training or not bn.track_running_stats,
# exponential_average_factor, bn.eps,
# )


+ 133
- 0
xnas/spaces/BigNAS/ops.py View File

@@ -0,0 +1,133 @@
import torch
import torch.nn as nn

from collections import OrderedDict
from xnas.spaces.OFA.ops import SEModule, build_activation
from xnas.spaces.OFA.utils import (
get_same_padding,
)

class MBConvLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
expand_ratio=6,
mid_channels=None,
act_func="relu6",
use_se=False,
channels_per_group=1,
):
super(MBConvLayer, self).__init__()

self.in_channels = in_channels
self.out_channels = out_channels

self.kernel_size = kernel_size
self.stride = stride
self.expand_ratio = expand_ratio
self.mid_channels = mid_channels
self.act_func = act_func
self.use_se = use_se
self.channels_per_group = channels_per_group

if self.mid_channels is None:
feature_dim = round(self.in_channels * self.expand_ratio)
else:
feature_dim = self.mid_channels

if self.expand_ratio == 1:
self.inverted_bottleneck = None
else:
self.inverted_bottleneck = nn.Sequential(OrderedDict([
("conv", nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]))

assert feature_dim % self.channels_per_group == 0
active_groups = feature_dim // self.channels_per_group
pad = get_same_padding(self.kernel_size)
# assert feature_dim % self.groups == 0
# active_groups = feature_dim // self.groups
depth_conv_modules = [
(
"conv",
nn.Conv2d(
feature_dim,
feature_dim,
kernel_size,
stride,
pad,
groups=active_groups,
bias=False,
),
),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]
if self.use_se:
depth_conv_modules.append(("se", SEModule(feature_dim)))
self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules))

self.point_linear = nn.Sequential(
OrderedDict(
[
("conv", nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
("bn", nn.BatchNorm2d(out_channels)),
]
)
)

def forward(self, x):
if self.inverted_bottleneck:
x = self.inverted_bottleneck(x)
x = self.depth_conv(x)
x = self.point_linear(x)
return x

@property
def module_str(self):
if self.mid_channels is None:
expand_ratio = self.expand_ratio
else:
expand_ratio = self.mid_channels // self.in_channels
layer_str = "%dx%d_MBConv%d_%s" % (
self.kernel_size,
self.kernel_size,
expand_ratio,
self.act_func.upper(),
)
if self.use_se:
layer_str = "SE_" + layer_str
layer_str += "_O%d" % self.out_channels
if self.channels_per_group is not None:  # this layer tracks channels_per_group rather than groups
layer_str += "_G%d" % self.channels_per_group
if isinstance(self.point_linear.bn, nn.GroupNorm):
layer_str += "_GN%d" % self.point_linear.bn.num_groups
elif isinstance(self.point_linear.bn, nn.BatchNorm2d):
layer_str += "_BN"

return layer_str

@property
def config(self):
return {
"name": MBConvLayer.__name__,
"in_channels": self.in_channels,
"out_channels": self.out_channels,
"kernel_size": self.kernel_size,
"stride": self.stride,
"expand_ratio": self.expand_ratio,
"mid_channels": self.mid_channels,
"act_func": self.act_func,
"use_se": self.use_se,
"groups": self.groups,
}

@staticmethod
def build_from_config(config):
return MBConvLayer(**config)

+ 134
- 0
xnas/spaces/BigNAS/utils.py View File

@@ -0,0 +1,134 @@
# Implementation adapted from AttentiveNAS: https://github.com/facebookresearch/AttentiveNAS

import torch
import torch.nn as nn
import copy
import math

multiply_adds = 1


def count_convNd(m, _, y):
cin = m.in_channels

kernel_ops = m.weight.size()[2] * m.weight.size()[3]
ops_per_element = kernel_ops
output_elements = y.nelement()

# cout x oW x oH
total_ops = cin * output_elements * ops_per_element // m.groups
m.total_ops = torch.Tensor([int(total_ops)])


def count_linear(m, _, __):
total_ops = m.in_features * m.out_features

m.total_ops = torch.Tensor([int(total_ops)])


register_hooks = {
nn.Conv1d: count_convNd,
nn.Conv2d: count_convNd,
nn.Conv3d: count_convNd,
######################################
nn.Linear: count_linear,
######################################
nn.Dropout: None,
nn.Dropout2d: None,
nn.Dropout3d: None,
nn.BatchNorm2d: None,
}


def profile(model, input_size=(1, 3, 224, 224), custom_ops=None):
handler_collection = []
custom_ops = {} if custom_ops is None else custom_ops

def add_hooks(m_):
if len(list(m_.children())) > 0:
return

m_.register_buffer('total_ops', torch.zeros(1))
m_.register_buffer('total_params', torch.zeros(1))

for p in m_.parameters():
m_.total_params += torch.Tensor([p.numel()])

m_type = type(m_)
fn = None

if m_type in custom_ops:
fn = custom_ops[m_type]
elif m_type in register_hooks:
fn = register_hooks[m_type]
else:
# print("Not implemented for ", m_)
pass

if fn is not None:
# print("Register FLOP counter for module %s" % str(m_))
_handler = m_.register_forward_hook(fn)
handler_collection.append(_handler)

original_device = model.parameters().__next__().device
training = model.training

model.eval()
model.apply(add_hooks)

x = torch.zeros(input_size).to(original_device)
with torch.no_grad():
model(x)

total_ops = 0
total_params = 0
for m in model.modules():
if len(list(m.children())) > 0: # skip for non-leaf module
continue
total_ops += m.total_ops
total_params += m.total_params

total_ops = total_ops.item()
total_params = total_params.item()

model.train(training)
model.to(original_device)

for handler in handler_collection:
handler.remove()

return total_ops, total_params


def count_net_flops_and_params(net, data_shape=(1, 3, 224, 224)):
if isinstance(net, nn.DataParallel):
net = net.module

net = copy.deepcopy(net)
flop, nparams = profile(net, data_shape)
return flop /1e6, nparams /1e6
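# Illustrative usage sketch (not part of this file): profiling a small model
# with the hook-based counter above. The returned values are MFLOPs and
# millions of parameters for one forward pass at the given input shape.
def _example_count_flops():
    toy = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(inplace=True),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(16, 10),
    )
    mflops, mparams = count_net_flops_and_params(toy, data_shape=(1, 3, 224, 224))
    return mflops, mparams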


def init_model(self, model_init="he_fout"):
""" Conv2d, BatchNorm2d, BatchNorm1d, Linear, """
for m in self.modules():
if isinstance(m, nn.Conv2d):
if model_init == 'he_fout':
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif model_init == 'he_fin':
n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
else:
raise NotImplementedError
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
if m.affine:
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
stdv = 1. / math.sqrt(m.weight.size(1))
m.weight.data.uniform_(-stdv, stdv)
if m.bias is not None:
m.bias.data.zero_()

+ 14
- 10
xnas/spaces/OFA/ProxylessNet/cnn.py View File

@@ -140,14 +140,14 @@ class MobileNetV2(ProxylessNASNet):
width_mult=1.0,
bn_param=(0.1, 1e-3),
dropout_rate=0.2,
ks=None,
expand_ratio=None,
ks=None, # a list that only includes values in {3, 5, 7}
expand_ratio=None, # in the proxyless space, only 3 or 6
depth_param=None,
stage_width_list=None,
):

ks = 3 if ks is None else ks
expand_ratio = 6 if expand_ratio is None else expand_ratio
expand_ratio = [6]*6 if expand_ratio is None else expand_ratio

input_channel = 32
last_channel = 1280
@@ -162,12 +162,12 @@ class MobileNetV2(ProxylessNASNet):
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[expand_ratio, 24, 2, 2],
[expand_ratio, 32, 3, 2],
[expand_ratio, 64, 4, 2],
[expand_ratio, 96, 3, 1],
[expand_ratio, 160, 3, 2],
[expand_ratio, 320, 1, 1],
[None, 24, 2, 2],
[None, 32, 3, 2],
[None, 64, 4, 2],
[None, 96, 3, 1],
[None, 160, 3, 2],
[None, 320, 1, 1],
]

if depth_param is not None:
@@ -179,6 +179,10 @@ class MobileNetV2(ProxylessNASNet):
for i in range(len(inverted_residual_setting)):
inverted_residual_setting[i][1] = stage_width_list[i]

if expand_ratio is not None:
for i in range(len(inverted_residual_setting)):
inverted_residual_setting[i][0] = expand_ratio[i]

ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
_pt = 0

@@ -201,7 +205,7 @@ class MobileNetV2(ProxylessNASNet):
stride = s
else:
stride = 1
if t == 1:
if t == 1: # only used for first block
kernel_size = 3
else:
kernel_size = ks[_pt]


+ 8
- 26
xnas/spaces/OFA/dynamic_ops.py View File

@@ -488,36 +488,18 @@ class DynamicMBConvLayer(nn.Module):
self.use_se = use_se

# build modules
max_middle_channel = make_divisible(
round(max(self.in_channel_list) * max(self.expand_ratio_list)))
max_middle_channel = make_divisible(round(max(self.in_channel_list) * max(self.expand_ratio_list)))
if max(self.expand_ratio_list) == 1:
self.inverted_bottleneck = None
else:
self.inverted_bottleneck = nn.Sequential(
OrderedDict(
[
(
"conv",
DynamicConv2d(
max(self.in_channel_list), max_middle_channel
),
),
("bn", DynamicBatchNorm2d(max_middle_channel)),
("act", build_activation(self.act_func)),
]
)
)
self.inverted_bottleneck = nn.Sequential(OrderedDict([
("conv", DynamicConv2d(max(self.in_channel_list), max_middle_channel)),
("bn", DynamicBatchNorm2d(max_middle_channel)),
("act", build_activation(self.act_func)),
]))

self.depth_conv = nn.Sequential(
OrderedDict(
[
(
"conv",
DynamicSeparableConv2d(
max_middle_channel, self.kernel_size_list, self.stride,
kernel_trans=kernel_trans
),
),
self.depth_conv = nn.Sequential(OrderedDict([
("conv", DynamicSeparableConv2d(max_middle_channel, self.kernel_size_list, stride=self.stride, kernel_trans=kernel_trans)),
("bn", DynamicBatchNorm2d(max_middle_channel)),
("act", build_activation(self.act_func)),
]


+ 88
- 18
xnas/spaces/OFA/ops.py View File

@@ -7,6 +7,7 @@ from xnas.spaces.OFA.utils import (
min_divisible_value,
get_same_padding,
make_divisible,
drop_connect,
)


@@ -46,12 +47,32 @@ def build_activation(act_func, inplace=True):
return Hswish(inplace=inplace)
elif act_func == "h_sigmoid":
return Hsigmoid(inplace=inplace)
elif act_func == 'swish':
return MemoryEfficientSwish()
elif act_func is None or act_func == "none":
return None
else:
raise ValueError("do not support: %s" % act_func)


class SwishImplementation(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result = i * torch.sigmoid(i)
ctx.save_for_backward(i)
return result

@staticmethod
def backward(ctx, grad_output):
i = ctx.saved_tensors[0]
sigmoid_i = torch.sigmoid(i)
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

class MemoryEfficientSwish(nn.Module):
def forward(self, x):
return SwishImplementation.apply(x)

class Hswish(nn.Module):
def __init__(self, inplace=True):
super(Hswish, self).__init__()
@@ -637,27 +658,20 @@ class MBConvLayer(nn.Module):
if self.expand_ratio == 1:
self.inverted_bottleneck = None
else:
self.inverted_bottleneck = nn.Sequential(
OrderedDict(
[
(
"conv",
nn.Conv2d(
self.in_channels, feature_dim, 1, 1, 0, bias=False
),
),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]
)
)
self.inverted_bottleneck = nn.Sequential(OrderedDict([
("conv", nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]))

pad = get_same_padding(self.kernel_size)
groups = (
active_groups = (
feature_dim
if self.groups is None
else min_divisible_value(feature_dim, self.groups)
)
# assert feature_dim % self.groups == 0
# active_groups = feature_dim // self.groups
depth_conv_modules = [
(
"conv",
@@ -667,7 +681,7 @@ class MBConvLayer(nn.Module):
kernel_size,
stride,
pad,
groups=groups,
groups=active_groups,
bias=False,
),
),
@@ -739,19 +753,26 @@ class MBConvLayer(nn.Module):


class ResidualBlock(nn.Module):
def __init__(self, conv, shortcut):
def __init__(self, conv, shortcut, drop_connect_rate=0):
super(ResidualBlock, self).__init__()

self.conv = conv
self.mobile_inverted_conv = self.conv # BigNAS
self.shortcut = shortcut
self.drop_connect_rate = drop_connect_rate

def forward(self, x):
in_channel = x.size(1)
if self.conv is None or isinstance(self.conv, ZeroLayer):
res = x
elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
res = self.conv(x)
else:
res = self.conv(x) + self.shortcut(x)
im = self.shortcut(x)
x = self.conv(x)
if self.drop_connect_rate > 0 and in_channel == im.size(1) and self.shortcut.reduction == 1:
x = drop_connect(x, p=self.drop_connect_rate, training=self.training)
res = x + im
return res

@property
@@ -955,3 +976,52 @@ class ResNetBottleneckBlock(nn.Module):
@staticmethod
def build_from_config(config):
return ResNetBottleneckBlock(**config)


class ShortcutLayer(nn.Module):
"""
NOTE:
This class provides functionality similar to `IdentityLayer`, but adds
average-pooling / 1x1-conv adaptation for mismatched sizes or channels
and removes parts of the original implementation.
"""
def __init__(self, in_channels, out_channels, reduction=1):
super(ShortcutLayer, self).__init__()

self.in_channels = in_channels
self.out_channels = out_channels
self.reduction = reduction

self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)

def forward(self, x):
if self.reduction > 1:
padding = 0 if x.size(-1) % 2 == 0 else 1
x = F.avg_pool2d(x, self.reduction, padding=padding)
if self.in_channels != self.out_channels:
x = self.conv(x)
return x

@property
def module_str(self):
if self.in_channels == self.out_channels and self.reduction == 1:
conv_str = 'IdentityShortcut'
else:
if self.reduction == 1:
conv_str = '%d-%d_Shortcut' % (self.in_channels, self.out_channels)
else:
conv_str = '%d-%d_R%d_Shortcut' % (self.in_channels, self.out_channels, self.reduction)
return conv_str

@property
def config(self):
return {
'name': ShortcutLayer.__name__,
'in_channels': self.in_channels,
'out_channels': self.out_channels,
'reduction': self.reduction,
}

@staticmethod
def build_from_config(config):
return ShortcutLayer(**config)

+ 25
- 0
xnas/spaces/OFA/utils.py View File

@@ -47,6 +47,7 @@ def get_same_padding(kernel_size):
assert kernel_size % 2 > 0, "kernel size should be odd number"
return kernel_size // 2


def make_divisible(v, divisor=8, min_val=None):
"""
This function is taken from the original tf repo.
@@ -67,6 +68,30 @@ def make_divisible(v, divisor=8, min_val=None):
return new_v


def drop_connect(inputs, p, training):
"""Drop connect.
Args:
inputs (tensor: BCHW): Input of this block.
p (float: 0.0~1.0): Probability of drop connection.
training (bool): The running mode.
Returns:
output: Output after drop connection.
"""
assert 0 <= p <= 1, 'p must be in range of [0,1]'
if not training:
return inputs
batch_size = inputs.shape[0]
keep_prob = 1.0 - p

# generate binary_tensor mask according to probability (p for 0, 1-p for 1)
random_tensor = keep_prob
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
binary_tensor = torch.floor(random_tensor)

output = inputs / keep_prob * binary_tensor
return output
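# Illustrative sketch (not part of this change): drop_connect keeps each
# sample with probability 1-p and rescales kept samples by 1/(1-p), so the
# expectation is preserved; in eval mode it is the identity.
def _example_drop_connect():
    x = torch.ones(8, 4, 2, 2)
    y_train = drop_connect(x, p=0.5, training=True)   # per-sample values are either 0 or 2
    y_eval = drop_connect(x, p=0.5, training=False)   # returned unchanged
    return y_train, y_eval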


""" BN related """

def clean_num_batch_tracked(net):


+ 271
- 0
xnas/spaces/ProxylessNAS/cnn.py View File

@@ -0,0 +1,271 @@
import json
import numpy as np
import torch.nn as nn

from xnas.spaces.OFA.ops import (
set_layer_from_config,
MBConvLayer,
ConvLayer,
IdentityLayer,
LinearLayer,
ResidualBlock,
GlobalAvgPool2d,
)
from xnas.spaces.OFA.utils import val2list, make_divisible
from xnas.spaces.OFA.MobileNetV3.cnn import WSConv_Network


__all__ = ["proxyless_base", "ProxylessNASNet", "MobileNetV2"]


def proxyless_base(
net_config=None,
n_classes=None,
bn_param=None,
dropout_rate=None,
):
assert net_config is not None, "Please input a network config"
net_config_json = json.load(open(net_config, "r"))

if n_classes is not None:
net_config_json["classifier"]["out_features"] = n_classes
if dropout_rate is not None:
net_config_json["classifier"]["dropout_rate"] = dropout_rate

net = ProxylessNASNet.build_from_config(net_config_json)
if bn_param is not None:
net.set_bn_param(*bn_param)

return net


class ProxylessNASNet(WSConv_Network):
def __init__(self, first_conv, blocks, feature_mix_layer, classifier):
super(ProxylessNASNet, self).__init__()

self.first_conv = first_conv
self.blocks = nn.ModuleList(blocks)
self.feature_mix_layer = feature_mix_layer
self.global_avg_pool = GlobalAvgPool2d(keep_dim=False)
self.classifier = classifier

def forward(self, x):
x = self.first_conv(x)
for block in self.blocks:
x = block(x)
if self.feature_mix_layer is not None:
x = self.feature_mix_layer(x)
x = self.global_avg_pool(x)
x = self.classifier(x)
return x

@property
def module_str(self):
_str = self.first_conv.module_str + "\n"
for block in self.blocks:
_str += block.module_str + "\n"
_str += self.feature_mix_layer.module_str + "\n"
_str += self.global_avg_pool.__repr__() + "\n"
_str += self.classifier.module_str
return _str

@property
def config(self):
return {
"name": ProxylessNASNet.__name__,
"bn": self.get_bn_param(),
"first_conv": self.first_conv.config,
"blocks": [block.config for block in self.blocks],
"feature_mix_layer": None
if self.feature_mix_layer is None
else self.feature_mix_layer.config,
"classifier": self.classifier.config,
}

@staticmethod
def build_from_config(config):
first_conv = set_layer_from_config(config["first_conv"])
feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
classifier = set_layer_from_config(config["classifier"])

blocks = []
for block_config in config["blocks"]:
blocks.append(ResidualBlock.build_from_config(block_config))

net = ProxylessNASNet(first_conv, blocks, feature_mix_layer, classifier)
if "bn" in config:
net.set_bn_param(**config["bn"])
else:
net.set_bn_param(momentum=0.1, eps=1e-3)

return net

def zero_last_gamma(self):
for m in self.modules():
if isinstance(m, ResidualBlock):
if isinstance(m.conv, MBConvLayer) and isinstance(
m.shortcut, IdentityLayer
):
m.conv.point_linear.bn.weight.data.zero_()

@property
def grouped_block_index(self):
info_list = []
block_index_list = []
for i, block in enumerate(self.blocks[1:], 1):
if block.shortcut is None and len(block_index_list) > 0:
info_list.append(block_index_list)
block_index_list = []
block_index_list.append(i)
if len(block_index_list) > 0:
info_list.append(block_index_list)
return info_list

def load_state_dict(self, state_dict, **kwargs):
current_state_dict = self.state_dict()

for key in state_dict:
if key not in current_state_dict:
assert ".mobile_inverted_conv." in key
new_key = key.replace(".mobile_inverted_conv.", ".conv.")
else:
new_key = key
current_state_dict[new_key] = state_dict[key]
super(ProxylessNASNet, self).load_state_dict(current_state_dict)


class MobileNetV2(ProxylessNASNet):
def __init__(
self,
n_classes=1000,
width_mult=1.0,
bn_param=(0.1, 1e-3),
dropout_rate=0.2,
ks=None, # a list that only includes values in {3, 5, 7}
expand_ratio=None, # in the proxyless space, only 3 or 6
depth_param=None,
stage_width_list=None,
):

ks = 3 if ks is None else ks
expand_ratio = [6]*6 if expand_ratio is None else expand_ratio

input_channel = 32
last_channel = 1280

input_channel = make_divisible(input_channel * width_mult)
last_channel = (
make_divisible(last_channel * width_mult)
if width_mult > 1.0
else last_channel
)

inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[0, 24, 2, 2],
[0, 32, 3, 2],
[0, 64, 4, 2],
[0, 96, 3, 1],
[0, 160, 3, 2],
[0, 320, 1, 1],
]

if depth_param is not None:
assert len(depth_param) == 6
# assert isinstance(depth_param, )
for i in range(1, len(inverted_residual_setting) - 1):
inverted_residual_setting[i][2] = depth_param[i-1]

if stage_width_list is not None:
assert len(stage_width_list) == 7
for i in range(len(inverted_residual_setting)):
inverted_residual_setting[i][1] = stage_width_list[i]

# if expand_ratio is not None:
# for i in range(len(inverted_residual_setting)):
# inverted_residual_setting[i][0] = expand_ratio[i]

# ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
_pt = 0

self.feature_idx = np.cumsum(depth_param)[[1, 3]]
# first conv layer
first_conv = ConvLayer(
3,
input_channel,
kernel_size=3,
stride=2,
use_bn=True,
act_func="relu6",
ops_order="weight_bn_act",
)
# inverted residual blocks
blocks = []
for t, c, n, s in inverted_residual_setting:
output_channel = make_divisible(c * width_mult)
for i in range(n):
if i == 0:
stride = s
else:
stride = 1
if t == 1: # only used for first block
kernel_size = 3
er = 1
else:
kernel_size = ks[_pt].item()
er = expand_ratio[_pt].item()
_pt += 1
mobile_inverted_conv = MBConvLayer(
in_channels=input_channel,
out_channels=output_channel,
kernel_size=kernel_size,
stride=stride,
expand_ratio=er,
)
if stride == 1:
if input_channel == output_channel:
shortcut = IdentityLayer(input_channel, input_channel)
else:
shortcut = None
else:
shortcut = None
blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
input_channel = output_channel
# 1x1_conv before global average pooling
feature_mix_layer = ConvLayer(
input_channel,
last_channel,
kernel_size=1,
use_bn=True,
act_func="relu6",
ops_order="weight_bn_act",
)

classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

super(MobileNetV2, self).__init__(
first_conv, blocks, feature_mix_layer, classifier
)

# set bn param
self.set_bn_param(*bn_param)
def forward_with_features(self, x, *args, **kwargs):
x = self.first_conv(x)
features = []
for i, block in enumerate(self.blocks):
if i in (self.feature_idx):
features.append(x)
x = block(x)
if self.feature_mix_layer is not None:
x = self.feature_mix_layer(x)
features.append(x)
assert len(features) == 3
x = self.global_avg_pool(x)
logits = self.classifier(x)
return features, logits

def _MobileNetV2():
return MobileNetV2()
