#22 dev

Merged
xfey merged 9 commits from dev into master 1 year ago
  1. configs/search/AttentiveNAS/eval.yaml (+132, -0)
  2. configs/search/AttentiveNAS/train.yaml (+100, -0)
  3. configs/search/BigNAS/eval.yaml (+89, -0)
  4. configs/search/BigNAS/search.yaml (+133, -0)
  5. configs/search/BigNAS/train.yaml (+95, -0)
  6. configs/search/RMINAS/rminas_proxyless_imagenet.yaml (+21, -0)
  7. examples/search/OFA/train_supernet.sh (+1, -1)
  8. scripts/search/AttentiveNAS/train_supernet.py (+271, -0)
  9. scripts/search/BigNAS/search.py (+186, -0)
  10. scripts/search/BigNAS/train_supernet.py (+254, -0)
  11. scripts/search/OFA/eval_supernet.py (+5, -5)
  12. scripts/search/OFA/train_supernet.py (+8, -8)
  13. scripts/search/RMINAS.py (+35, -25)
  14. scripts/train/DARTS.py (+3, -3)
  15. tests/ofa_matrices_test.py (+1, -1)
  16. xnas/algorithms/AttentiveNAS/sampler.py (+117, -0)
  17. xnas/algorithms/RMINAS/sampler/RF_sampling.py (+47, -1)
  18. xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet.py (+1, -1)
  19. xnas/algorithms/RMINAS/utils/random_data.py (+8, -10)
  20. xnas/core/builder.py (+12, -1)
  21. xnas/core/config.py (+6, -1)
  22. xnas/datasets/auto_augment_tf.py (+402, -0)
  23. xnas/datasets/imagenet.py (+64, -136)
  24. xnas/datasets/loader.py (+42, -30)
  25. xnas/datasets/transforms_imagenet.py (+124, -0)
  26. xnas/runner/criterion.py (+36, -4)
  27. xnas/runner/optimizer.py (+0, -1)
  28. xnas/runner/scheduler.py (+4, -2)
  29. xnas/runner/trainer.py (+7, -7)
  30. xnas/runner/trainer_spos.py (+8, -9)
  31. xnas/spaces/AttentiveNAS/cnn.py (+652, -0)
  32. xnas/spaces/BigNAS/cnn.py (+653, -0)
  33. xnas/spaces/BigNAS/dynamic_layers.py (+331, -0)
  34. xnas/spaces/BigNAS/dynamic_ops.py (+181, -0)
  35. xnas/spaces/BigNAS/ops.py (+133, -0)
  36. xnas/spaces/BigNAS/utils.py (+134, -0)
  37. xnas/spaces/OFA/ProxylessNet/cnn.py (+14, -10)
  38. xnas/spaces/OFA/dynamic_ops.py (+8, -26)
  39. xnas/spaces/OFA/ops.py (+88, -18)
  40. xnas/spaces/OFA/utils.py (+25, -0)
  41. xnas/spaces/ProxylessNAS/cnn.py (+271, -0)

configs/search/AttentiveNAS/eval.yaml (+132, -0)

@@ -0,0 +1,132 @@
NUM_GPUS: 4
RNG_SEED: 2
SPACE:
  NAME: 'attentivenas'
LOADER:
  DATASET: 'imagenet'
  NUM_CLASSES: 1000
  BATCH_SIZE: 256
  NUM_WORKERS: 4
  USE_VAL: True
  TRANSFORM: "auto_augment_tf"
SEARCH:
  IM_SIZE: 224
ATTENTIVENAS:
  BN_MOMENTUM: 0.1
  BN_EPS: 1.e-5
  POST_BN_CALIBRATION_BATCH_NUM: 64
  ACTIVE_SUBNET: # chosen from following settings
    # attentive_nas_a0
    RESOLUTION: 192
    WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792]
    KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3]
    EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
    DEPTH: [1, 3, 3, 3, 3, 3, 1]

    # # attentive_nas_a1
    # RESOLUTION: 224
    # WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1984]
    # KERNEL_SIZE: [3, 3, 3, 5, 3, 5, 3]
    # EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
    # DEPTH: [1, 3, 3, 3, 3, 3, 1]

    # # attentive_nas_a2
    # RESOLUTION: 224
    # WIDTH: [16, 16, 24, 32, 64, 112, 200, 224, 1984]
    # KERNEL_SIZE: [3, 3, 3, 3, 3, 5, 3]
    # EXPAND_RATIO: [1, 4, 5, 4, 4, 6, 6]
    # DEPTH: [1, 3, 3, 3, 3, 4, 1]

    # # attentive_nas_a3
    # RESOLUTION: 224
    # WIDTH: [16, 16, 24, 32, 64, 112, 208, 224, 1984]
    # KERNEL_SIZE: [3, 3, 3, 5, 3, 3, 3]
    # EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
    # DEPTH: [2, 3, 3, 4, 3, 5, 1]

    # # attentive_nas_a4
    # RESOLUTION: 256
    # WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1984]
    # KERNEL_SIZE: [3, 3, 3, 5, 3, 5, 3]
    # EXPAND_RATIO: [1, 4, 4, 5, 4, 6, 6]
    # DEPTH: [1, 3, 3, 4, 3, 5, 1]

    # # attentive_nas_a5
    # RESOLUTION: 256
    # WIDTH: [16, 16, 24, 32, 72, 112, 192, 216, 1792]
    # KERNEL_SIZE: [3, 3, 3, 5, 3, 3, 3]
    # EXPAND_RATIO: [1, 4, 5, 4, 4, 6, 6]
    # DEPTH: [1, 3, 3, 3, 4, 6, 1]

    # # attentive_nas_a6
    # RESOLUTION: 288
    # WIDTH: [16, 16, 24, 32, 64, 112, 216, 224, 1984]
    # KERNEL_SIZE: [3, 3, 3, 3, 3, 5, 3]
    # EXPAND_RATIO: [1, 4, 6, 5, 4, 6, 6]
    # DEPTH: [1, 3, 3, 4, 4, 6, 1]
  SUPERNET_CFG:
    use_v3_head: True
    resolutions: [192, 224, 256, 288]
    first_conv:
      c: [16, 24]
      act_func: 'swish'
      s: 2
    mb1:
      c: [16, 24]
      d: [1, 2]
      k: [3, 5]
      t: [1]
      s: 1
      act_func: 'swish'
      se: False
    mb2:
      c: [24, 32]
      d: [3, 4, 5]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: False
    mb3:
      c: [32, 40]
      d: [3, 4, 5, 6]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: True
    mb4:
      c: [64, 72]
      d: [3, 4, 5, 6]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: False
    mb5:
      c: [112, 120, 128]
      d: [3, 4, 5, 6, 7, 8]
      k: [3, 5]
      t: [4, 5, 6]
      s: 1
      act_func: 'swish'
      se: True
    mb6:
      c: [192, 200, 208, 216]
      d: [3, 4, 5, 6, 7, 8]
      k: [3, 5]
      t: [6]
      s: 2
      act_func: 'swish'
      se: True
    mb7:
      c: [216, 224]
      d: [1, 2]
      k: [3, 5]
      t: [6]
      s: 1
      act_func: 'swish'
      se: True
    last_conv:
      c: [1792, 1984]
      act_func: 'swish'
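
Each commented block above (attentive_nas_a0 through a6) is one of the AttentiveNAS reference subnets; evaluation activates exactly one of them on the supernet and recalibrates BN statistics before testing. The sketch below shows how such an entry maps onto the set_active_subnet / get_active_subnet calls used by the trainer and search scripts in this PR; the wrapper function and variable names are illustrative only.

def evaluate_active_subnet(supernet):
    # `supernet` stands for the 'attentivenas' network returned by space_builder(); placeholder name.
    active = {  # values copied from the attentive_nas_a0 block above
        "resolution": 192,
        "width": [16, 16, 24, 32, 64, 112, 192, 216, 1792],
        "depth": [1, 3, 3, 3, 3, 3, 1],
        "kernel_size": [3, 3, 3, 3, 3, 3, 3],
        "expand_ratio": [1, 4, 4, 4, 4, 6, 6],
    }
    supernet.set_active_subnet(
        active["resolution"], active["width"], active["depth"],
        active["kernel_size"], active["expand_ratio"],
    )
    return supernet.get_active_subnet()  # standalone module; BN is recalibrated before testing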

configs/search/AttentiveNAS/train.yaml (+100, -0)

@@ -0,0 +1,100 @@
NUM_GPUS: 4
RNG_SEED: 0
SPACE:
  NAME: 'attentivenas'
LOADER:
  DATASET: 'imagenet'
  NUM_CLASSES: 1000
  BATCH_SIZE: 64 # 32*8 in total
  NUM_WORKERS: 4
  USE_VAL: True
  TRANSFORM: "auto_augment_tf"
OPTIM:
  GRAD_CLIP: 1.
  WARMUP_EPOCH: 5
  MAX_EPOCH: 360
  WEIGHT_DECAY: 1.e-5
  BASE_LR: 0.2
  NESTEROV: True
SEARCH:
  LOSS_FUN: "cross_entropy_smooth"
  LABEL_SMOOTH: 0.1
TRAIN:
  DROP_PATH_PROB: 0.2
ATTENTIVENAS:
  SANDWICH_NUM: 4 # max + 2*middle + min
  DROP_CONNECT: 0.2
  BN_MOMENTUM: 0.
  BN_EPS: 1.e-5
  POST_BN_CALIBRATION_BATCH_NUM: 64
  SAMPLER:
    METHOD: 'bestup'
    MAP_PATH: 'xnas/algorithms/AttentiveNAS/flops_archs_off_table.map'
    DISCRETIZE_STEP: 25
    NUM_TRIALS: 3
  SUPERNET_CFG:
    use_v3_head: True
    resolutions: [192, 224, 256, 288]
    first_conv:
      c: [16, 24]
      act_func: 'swish'
      s: 2
    mb1:
      c: [16, 24]
      d: [1, 2]
      k: [3, 5]
      t: [1]
      s: 1
      act_func: 'swish'
      se: False
    mb2:
      c: [24, 32]
      d: [3, 4, 5]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: False
    mb3:
      c: [32, 40]
      d: [3, 4, 5, 6]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: True
    mb4:
      c: [64, 72]
      d: [3, 4, 5, 6]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: False
    mb5:
      c: [112, 120, 128]
      d: [3, 4, 5, 6, 7, 8]
      k: [3, 5]
      t: [4, 5, 6]
      s: 1
      act_func: 'swish'
      se: True
    mb6:
      c: [192, 200, 208, 216]
      d: [3, 4, 5, 6, 7, 8]
      k: [3, 5]
      t: [6]
      s: 2
      act_func: 'swish'
      se: True
    mb7:
      c: [216, 224]
      d: [1, 2]
      k: [3, 5]
      t: [6]
      s: 1
      act_func: 'swish'
      se: True
    last_conv:
      c: [1792, 1984]
      act_func: 'swish'
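
The training configs pick LOSS_FUN: "cross_entropy_smooth" with LABEL_SMOOTH: 0.1; the actual criterion lives in xnas/runner/criterion.py, which this PR also touches. As a reference point, uniform label smoothing is usually implemented along the lines of the sketch below (function name and signature here are illustrative, not the repo's API).

import torch
import torch.nn.functional as F

def cross_entropy_smooth(logits, target, eps=0.1):
    # Smoothed cross entropy: mix the one-hot target with a uniform distribution (weight eps).
    log_probs = F.log_softmax(logits, dim=-1)
    nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
    uniform = -log_probs.mean(dim=-1)
    return ((1.0 - eps) * nll + eps * uniform).mean()

logits = torch.randn(4, 1000)
target = torch.randint(0, 1000, (4,))
print(cross_entropy_smooth(logits, target, eps=0.1))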

configs/search/BigNAS/eval.yaml (+89, -0)

@@ -0,0 +1,89 @@
NUM_GPUS: 4
RNG_SEED: 2
SPACE:
  NAME: 'bignas'
LOADER:
  DATASET: 'imagenet'
  NUM_CLASSES: 1000
  BATCH_SIZE: 128
  NUM_WORKERS: 4
  USE_VAL: True
  TRANSFORM: "auto_augment_tf"
SEARCH:
  IM_SIZE: 224
BIGNAS:
  BN_MOMENTUM: 0.1
  BN_EPS: 1.e-5
  POST_BN_CALIBRATION_BATCH_NUM: 64
  ACTIVE_SUBNET: # subnet for evaluation
    RESOLUTION: 192
    WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792]
    KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3]
    EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
    DEPTH: [1, 3, 3, 3, 3, 3, 1]
  SUPERNET_CFG:
    use_v3_head: True
    resolutions: [192, 224, 256, 288]
    first_conv:
      c: [16, 24]
      act_func: 'swish'
      s: 2
    mb1:
      c: [16, 24]
      d: [1, 2]
      k: [3, 5]
      t: [1]
      s: 1
      act_func: 'swish'
      se: False
    mb2:
      c: [24, 32]
      d: [3, 4, 5]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: False
    mb3:
      c: [32, 40]
      d: [3, 4, 5, 6]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: True
    mb4:
      c: [64, 72]
      d: [3, 4, 5, 6]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: False
    mb5:
      c: [112, 120, 128]
      d: [3, 4, 5, 6, 7, 8]
      k: [3, 5]
      t: [4, 5, 6]
      s: 1
      act_func: 'swish'
      se: True
    mb6:
      c: [192, 200, 208, 216]
      d: [3, 4, 5, 6, 7, 8]
      k: [3, 5]
      t: [6]
      s: 2
      act_func: 'swish'
      se: True
    mb7:
      c: [216, 224]
      d: [1, 2]
      k: [3, 5]
      t: [6]
      s: 1
      act_func: 'swish'
      se: True
    last_conv:
      c: [1792, 1984]
      act_func: 'swish'

configs/search/BigNAS/search.yaml (+133, -0)

@@ -0,0 +1,133 @@
NUM_GPUS: 1
RNG_SEED: 2
SPACE:
  NAME: 'bignas'
LOADER:
  DATASET: 'imagenet'
  NUM_CLASSES: 1000
  BATCH_SIZE: 128
  NUM_WORKERS: 8
  USE_VAL: True
  TRANSFORM: "auto_augment_tf"
SEARCH:
  IM_SIZE: 224
  WEIGHTS: "exp/search/test/checkpoints/best_model_epoch_0009.pyth"
BIGNAS:
  CONSTRAINT_FLOPS: 6.e+8 # 600M
  NUM_MUTATE: 200
  BN_MOMENTUM: 0.1
  BN_EPS: 1.e-5
  POST_BN_CALIBRATION_BATCH_NUM: 64
  # ACTIVE_SUBNET: # subnet for evaluation
  #   RESOLUTION: 192
  #   WIDTH: [16, 16, 24, 32, 64, 112, 192, 216, 1792]
  #   KERNEL_SIZE: [3, 3, 3, 3, 3, 3, 3]
  #   EXPAND_RATIO: [1, 4, 4, 4, 4, 6, 6]
  #   DEPTH: [1, 3, 3, 3, 3, 3, 1]
  SEARCH_CFG_SETS:
    resolutions: [224, 256]
    first_conv:
      c: [16]
    mb1:
      c: [16]
      d: [2]
      k: [3]
      t: [1]
    mb2:
      c: [24]
      d: [3]
      k: [3]
      t: [5]
    mb3:
      c: [32]
      d: [4]
      k: [3]
      t: [5]
    mb4:
      c: [64]
      d: [5]
      k: [3]
      t: [5]
    mb5:
      c: [120]
      d: [6]
      k: [3]
      t: [5]
    mb6:
      c: [192]
      d: [6]
      k: [3, 5]
      t: [6]
    mb7:
      c: [216]
      d: [2]
      k: [3]
      t: [6]
    last_conv:
      c: [1792]
  SUPERNET_CFG:
    use_v3_head: True
    resolutions: [192, 224, 256, 288]
    first_conv:
      c: [16, 24]
      act_func: 'swish'
      s: 2
    mb1:
      c: [16, 24]
      d: [1, 2]
      k: [3, 5]
      t: [1]
      s: 1
      act_func: 'swish'
      se: False
    mb2:
      c: [24, 32]
      d: [3, 4, 5]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: False
    mb3:
      c: [32, 40]
      d: [3, 4, 5, 6]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: True
    mb4:
      c: [64, 72]
      d: [3, 4, 5, 6]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: False
    mb5:
      c: [112, 120, 128]
      d: [3, 4, 5, 6, 7, 8]
      k: [3, 5]
      t: [4, 5, 6]
      s: 1
      act_func: 'swish'
      se: True
    mb6:
      c: [192, 200, 208, 216]
      d: [3, 4, 5, 6, 7, 8]
      k: [3, 5]
      t: [6]
      s: 2
      act_func: 'swish'
      se: True
    mb7:
      c: [216, 224]
      d: [1, 2]
      k: [3, 5]
      t: [6]
      s: 1
      act_func: 'swish'
      se: True
    last_conv:
      c: [1792, 1984]
      act_func: 'swish'
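
SEARCH_CFG_SETS pins most stages to a single option, so the coarse phase only needs to benchmark a handful of subnets; the grid size is the product of the per-stage option counts, exactly as get_all_subnets() in scripts/search/BigNAS/search.py enumerates it. A quick check of the arithmetic for the values above (illustrative script, not part of the diff):

from math import prod

# option counts (c, d, k, t) copied from SEARCH_CFG_SETS above
stages = {
    "mb1": (1, 1, 1, 1), "mb2": (1, 1, 1, 1), "mb3": (1, 1, 1, 1), "mb4": (1, 1, 1, 1),
    "mb5": (1, 1, 1, 1), "mb6": (1, 1, 2, 1), "mb7": (1, 1, 1, 1),
}
per_stage = prod(prod(v) for v in stages.values())  # = 2, only mb6's kernel size varies
total = 2 * 1 * per_stage * 1                       # resolutions * first_conv.c * stages * last_conv.c
print(total)                                        # 4 candidate subnets in the coarse phase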

configs/search/BigNAS/train.yaml (+95, -0)

@@ -0,0 +1,95 @@
NUM_GPUS: 4
RNG_SEED: 0
SPACE:
  NAME: 'bignas'
LOADER:
  DATASET: 'imagenet'
  NUM_CLASSES: 1000
  BATCH_SIZE: 32
  NUM_WORKERS: 4
  USE_VAL: True
  TRANSFORM: "auto_augment_tf"
OPTIM:
  GRAD_CLIP: 1.
  WARMUP_EPOCH: 5
  MAX_EPOCH: 360
  WEIGHT_DECAY: 1.e-5
  BASE_LR: 0.1
  NESTEROV: True
SEARCH:
  LOSS_FUN: "cross_entropy_smooth"
  LABEL_SMOOTH: 0.1
TRAIN:
  DROP_PATH_PROB: 0.2
BIGNAS:
  SANDWICH_NUM: 4 # max + 2*middle + min
  DROP_CONNECT: 0.2
  BN_MOMENTUM: 0.
  BN_EPS: 1.e-5
  POST_BN_CALIBRATION_BATCH_NUM: 64
  SUPERNET_CFG:
    use_v3_head: True
    resolutions: [192, 224, 256, 288]
    first_conv:
      c: [16, 24]
      act_func: 'swish'
      s: 2
    mb1:
      c: [16, 24]
      d: [1, 2]
      k: [3, 5]
      t: [1]
      s: 1
      act_func: 'swish'
      se: False
    mb2:
      c: [24, 32]
      d: [3, 4, 5]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: False
    mb3:
      c: [32, 40]
      d: [3, 4, 5, 6]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: True
    mb4:
      c: [64, 72]
      d: [3, 4, 5, 6]
      k: [3, 5]
      t: [4, 5, 6]
      s: 2
      act_func: 'swish'
      se: False
    mb5:
      c: [112, 120, 128]
      d: [3, 4, 5, 6, 7, 8]
      k: [3, 5]
      t: [4, 5, 6]
      s: 1
      act_func: 'swish'
      se: True
    mb6:
      c: [192, 200, 208, 216]
      d: [3, 4, 5, 6, 7, 8]
      k: [3, 5]
      t: [6]
      s: 2
      act_func: 'swish'
      se: True
    mb7:
      c: [216, 224]
      d: [1, 2]
      k: [3, 5]
      t: [6]
      s: 1
      act_func: 'swish'
      se: True
    last_conv:
      c: [1792, 1984]
      act_func: 'swish'

configs/search/RMINAS/rminas_proxyless_imagenet.yaml (+21, -0)

@@ -0,0 +1,21 @@
SPACE:
  NAME: 'proxyless'
LOADER:
  DATASET: 'imagenet'
  NUM_CLASSES: 10
  NUM_WORKERS: 0
  BATCH_SIZE: 128
OPTIM:
  BASE_LR: 0.025
  MOMENTUM: 0.9
  WEIGHT_DECAY: 0.0003
  MAX_EPOCH: 500
TRAIN:
  CHANNELS: 16
  LAYERS: 8
RMINAS:
  LOSS_BETA: 0.8
  RF_WARMUP: 100
  RF_THRESRATE: 0.05
  RF_SUCC: 100
OUT_DIR: 'exp/rminas'

examples/search/OFA/train_supernet.sh (+1, -1)

@@ -1,4 +1,4 @@
OUT_NAME="OFA_trail_25"
OUT_NAME="OFA_trial_25"
TASKS="normal_1 kernel_1 depth_1 depth_2 expand_1 expand_2" TASKS="normal_1 kernel_1 depth_1 depth_2 expand_1 expand_2"


for loop in $TASKS for loop in $TASKS


scripts/search/AttentiveNAS/train_supernet.py (+271, -0)

@@ -0,0 +1,271 @@
"""AttentiveNAS supernet training"""

import os
import random
import operator

import torch
import torch.nn as nn

import xnas.core.config as config
from xnas.datasets.loader import get_normal_dataloader
import xnas.logger.meter as meter
import xnas.logger.logging as logging
from xnas.core.config import cfg
from xnas.core.builder import *

# DDP
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# AttentiveNAS
from xnas.runner.trainer import Trainer
from xnas.runner.scheduler import adjust_learning_rate_per_batch
from xnas.algorithms.AttentiveNAS.sampler import ArchSampler
from xnas.spaces.OFA.utils import list_mean
from xnas.spaces.BigNAS.utils import init_model


# Load config and check
config.load_configs()
logger = logging.get_logger(__name__)


def main(local_rank, world_size):
setup_env()
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl', rank=local_rank, world_size=world_size)
# Network
net = space_builder().to(local_rank)
init_model(net)
# Loss function
criterion = criterion_builder()
soft_criterion = criterion_builder('kl_soft')
# Data loaders
[train_loader, valid_loader] = get_normal_dataloader()
# Optimizers
net_params = [
# parameters with weight decay
{"params": net.get_parameters(['bn', 'bias'], mode="exclude"), "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
# parameters without weight decay
{"params": net.get_parameters(['bn', 'bias'], mode="include") , "weight_decay": 0},
]
optimizer = optimizer_builder("SGD", net_params)
# sampler for AttentiveNAS
sampler = ArchSampler(
cfg.ATTENTIVENAS.SAMPLER.MAP_PATH, cfg.ATTENTIVENAS.SAMPLER.DISCRETIZE_STEP, net, None
)
net = DDP(net, device_ids=[local_rank], find_unused_parameters=True)
# Initialize Recorder
attnas_trainer = AttentivenasTrainer(
model=net,
criterion=criterion,
soft_criterion=soft_criterion,
sampler=sampler,
optimizer=optimizer,
lr_scheduler=None,
train_loader=train_loader,
test_loader=valid_loader,
)
# Resume
start_epoch = attnas_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0
# Training
logger.info("Start AttentiveNAS training.")
dist.barrier()
attnas_trainer.start()
for cur_epoch in range(start_epoch, cfg.OPTIM.WARMUP_EPOCH+cfg.OPTIM.MAX_EPOCH):
attnas_trainer.train_epoch(cur_epoch, rank=local_rank)
if local_rank == 0:
if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH:
attnas_trainer.validate(cur_epoch, local_rank)
attnas_trainer.finish()
dist.barrier()
torch.cuda.empty_cache()


class AttentivenasTrainer(Trainer):
"""Trainer for AttentiveNAS."""
def __init__(self, model, criterion, soft_criterion, sampler, optimizer, lr_scheduler, train_loader, test_loader):
super().__init__(model, criterion, optimizer, lr_scheduler, train_loader, test_loader)
self.sandwich_sample_num = max(2, cfg.ATTENTIVENAS.SANDWICH_NUM) # containing max & min
self.soft_criterion = soft_criterion
self.sampler = sampler

def train_epoch(self, cur_epoch, rank=0):
self.model.train()
# lr = self.lr_scheduler.get_last_lr()[0]
cur_step = cur_epoch * len(self.train_loader)
# self.writer.add_scalar('train/lr', lr, cur_step)
self.train_meter.iter_tic()
self.train_loader.sampler.set_epoch(cur_epoch) # DDP
for cur_iter, (inputs, labels) in enumerate(self.train_loader):
inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True)
# Adjust lr per iter
cur_lr = adjust_learning_rate_per_batch(
epoch=cur_epoch,
n_iter=len(self.train_loader),
iter=cur_iter,
warmup=(cur_epoch < cfg.OPTIM.WARMUP_EPOCH),
)
for param_group in self.optimizer.param_groups:
param_group["lr"] = cur_lr
# self.writer.add_scalar('train/lr', cur_lr, cur_step)
self.optimizer.zero_grad()
## Sandwich Rule ##
# Step 1. Largest network sampling & regularization
self.model.module.sample_max_subnet()
self.model.module.set_dropout_rate(cfg.TRAIN.DROP_PATH_PROB, cfg.ATTENTIVENAS.DROP_CONNECT)
preds = self.model(inputs)
loss = self.criterion(preds, labels)
loss.backward()
with torch.no_grad():
soft_logits = preds.clone().detach()
# Step 2. sample smaller networks
self.model.module.set_dropout_rate(0, 0)
for arch_id in range(1, self.sandwich_sample_num):
if arch_id == self.sandwich_sample_num - 1:
self.model.module.sample_min_subnet()
else:
if self.sampler is not None:
sampling_method = cfg.ATTENTIVENAS.SAMPLER.METHOD
if sampling_method in ['bestup', 'worstup']:
target_flops = self.sampler.sample_one_target_flops()
candidate_archs = self.sampler.sample_archs_according_to_flops(
target_flops, n_samples=cfg.ATTENTIVENAS.SAMPLER.NUM_TRIALS
)
my_pred_accs = []
for arch in candidate_archs:
self.model.module.set_active_subnet(**arch)
with torch.no_grad():
my_pred_accs.append(-1.0 * self.criterion(self.model(inputs), labels))
if sampling_method == 'bestup':
idx, _ = max(enumerate(my_pred_accs), key=operator.itemgetter(1))
else:
idx, _ = min(enumerate(my_pred_accs), key=operator.itemgetter(1))
self.model.module.set_active_subnet(**candidate_archs[idx]) #reset
else:
subnet_seed = int("%d%.3d%.3d" % (cur_step, arch_id, 0))
random.seed(subnet_seed)
self.model.module.sample_active_subnet()
preds = self.model(inputs)
if self.soft_criterion is not None:
loss = self.soft_criterion(preds, soft_logits)
else:
loss = self.criterion(preds, labels)
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), cfg.OPTIM.GRAD_CLIP)
self.optimizer.step()
# calculating errors. The source code of AttentiveNAS uses statistics of the smallest network and XNAS follows.
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
self.train_meter.iter_toc()
self.train_meter.update_stats(top1_err, top5_err, loss, cur_lr, inputs.size(0) * cfg.NUM_GPUS)
self.train_meter.log_iter_stats(cur_epoch, cur_iter)
self.train_meter.iter_tic()
# self.writer.add_scalar('train/loss', i_loss, cur_step)
# self.writer.add_scalar('train/top1_error', i_top1err, cur_step)
# self.writer.add_scalar('train/top5_error', i_top5err, cur_step)
cur_step += 1
# Log epoch stats
self.train_meter.log_epoch_stats(cur_epoch)
self.train_meter.reset()
# self.lr_scheduler.step()
# Saving checkpoint
if rank==0 and (cur_epoch + 1) % cfg.SAVE_PERIOD == 0:
self.saving(cur_epoch)
@torch.no_grad()
def test_epoch(self, subnet, cur_epoch, rank=0):
subnet.eval()
self.test_meter.reset(True)
self.test_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(self.test_loader):
inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True)
preds = subnet(inputs)
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item()
self.test_meter.iter_toc()
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_avg()
top5_err = self.test_meter.mb_top5_err.get_win_avg()
# self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
# self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch)
# self.test_meter.reset()
return top1_err, top5_err


def validate(self, cur_epoch, rank, bn_calibration=True):
subnets_to_be_evaluated = {
'attentive_nas_min_net': {},
'attentive_nas_max_net': {},
}
top1_list, top5_list = [], []
with torch.no_grad():
for net_id in subnets_to_be_evaluated:
if net_id == 'attentive_nas_min_net':
self.model.module.sample_min_subnet()
elif net_id == 'attentive_nas_max_net':
self.model.module.sample_max_subnet()
elif net_id.startswith('attentive_nas_random_net'):
self.model.module.sample_active_subnet()
else:
self.model.module.set_active_subnet(
subnets_to_be_evaluated[net_id]['resolution'],
subnets_to_be_evaluated[net_id]['width'],
subnets_to_be_evaluated[net_id]['depth'],
subnets_to_be_evaluated[net_id]['kernel_size'],
subnets_to_be_evaluated[net_id]['expand_ratio'],
)

subnet = self.model.module.get_active_subnet()
subnet.to(rank)
logger.info("evaluating subnet {}".format(net_id))
if bn_calibration:
subnet.eval()
logger.info("Calibrating BN running statistics.")
subnet.reset_running_stats_for_calibration()
for cur_iter, (inputs, _) in enumerate(self.train_loader):
if cur_iter >= cfg.ATTENTIVENAS.POST_BN_CALIBRATION_BATCH_NUM:
break
inputs = inputs.to(rank)
subnet(inputs) # forward only
top1_err, top5_err = self.test_epoch(subnet, cur_epoch, rank)
top1_list.append(top1_err), top5_list.append(top5_err)
logger.info("Average@all_subnets top1_err:{} top5_err:{}".format(list_mean(top1_list), list_mean(top5_list)))
if self.best_err > list_mean(top1_list):
self.best_err = list_mean(top1_list)
self.saving(cur_epoch, best=True)


if __name__ == '__main__':
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '23333'
if torch.cuda.is_available():
cfg.NUM_GPUS = torch.cuda.device_count()
mp.spawn(main, nprocs=cfg.NUM_GPUS, args=(cfg.NUM_GPUS,), join=True)
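
The heart of AttentivenasTrainer.train_epoch above is the sandwich rule with inplace distillation: the largest subnet is trained on the hard labels and its logits supervise the smaller subnets sampled in the same step. The sketch below strips away DDP, meters and the bestup/worstup candidate selection to show just that skeleton; all names are placeholders, not the repo's API.

import torch

def sandwich_step(supernet, criterion, soft_criterion, optimizer, inputs, labels, n_subnets=4):
    optimizer.zero_grad()
    # 1) largest subnet: trained on hard labels, its logits become the teacher
    supernet.sample_max_subnet()
    teacher_logits = supernet(inputs)
    criterion(teacher_logits, labels).backward()
    teacher_logits = teacher_logits.detach()
    # 2) smallest + randomly sampled middle subnets: distilled from the teacher logits
    for k in range(1, n_subnets):
        if k == n_subnets - 1:
            supernet.sample_min_subnet()
        else:
            supernet.sample_active_subnet()
        student_logits = supernet(inputs)
        soft_criterion(student_logits, teacher_logits).backward()
    # gradients from all sampled subnets were accumulated; apply them in one step
    torch.nn.utils.clip_grad_norm_(supernet.parameters(), 1.0)
    optimizer.step()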

scripts/search/BigNAS/search.py (+186, -0)

@@ -0,0 +1,186 @@
"""BigNAS subnet searching: Coarse-to-fine Architecture Selection"""

import numpy as np
from itertools import product

import torch

import xnas.core.config as config
import xnas.logger.meter as meter
import xnas.logger.logging as logging
from xnas.core.builder import *
from xnas.core.config import cfg
from xnas.datasets.loader import get_normal_dataloader
from xnas.logger.meter import TestMeter


# Load config and check
config.load_configs()
logger = logging.get_logger(__name__)


def get_all_subnets():
# get all subnets
all_subnets = []
subnet_sets = cfg.BIGNAS.SEARCH_CFG_SETS
stage_names = ['mb1', 'mb2', 'mb3', 'mb4', 'mb5', 'mb6', 'mb7']

mb_stage_subnets = []
for mbstage in stage_names:
mb_block_cfg = getattr(subnet_sets, mbstage)
mb_stage_subnets.append(list(product(
mb_block_cfg.c,
mb_block_cfg.d,
mb_block_cfg.k,
mb_block_cfg.t
)))

all_mb_stage_subnets = list(product(*mb_stage_subnets))

resolutions = getattr(subnet_sets, 'resolutions')
first_conv = getattr(subnet_sets, 'first_conv')
last_conv = getattr(subnet_sets, 'last_conv')

for res in resolutions:
for fc in first_conv.c:
for mb in all_mb_stage_subnets:
np_mb_choice = np.array(mb)
width = np_mb_choice[:, 0].tolist() # c
depth = np_mb_choice[:, 1].tolist() # d
kernel = np_mb_choice[:, 2].tolist() # k
expand = np_mb_choice[:, 3].tolist() # t
for lc in last_conv.c:
all_subnets.append({
'resolution': res,
'width': [fc] + width + [lc],
'depth': depth,
'kernel_size': kernel,
'expand_ratio': expand
})
return all_subnets


def main():
setup_env()
supernet = space_builder().cuda()
supernet.load_weights_from_pretrained_models(cfg.SEARCH.WEIGHTS)

[train_loader, valid_loader] = get_normal_dataloader()

test_meter = TestMeter(len(valid_loader))

all_subnets = get_all_subnets()
benchmarks = []

# Phase 1. coarse search
for k,subnet_cfg in enumerate(all_subnets):
supernet.set_active_subnet(
subnet_cfg['resolution'],
subnet_cfg['width'],
subnet_cfg['depth'],
subnet_cfg['kernel_size'],
subnet_cfg['expand_ratio'],
)
subnet = supernet.get_active_subnet().cuda()
# Validate
top1_err, top5_err = validate(subnet, train_loader, valid_loader, test_meter)
flops = supernet.compute_active_subnet_flops()

logger.info("[{}/{}] flops:{} top1_err:{} top5_err:{}".format(
k+1, len(all_subnets), flops, top1_err, top5_err
))

benchmarks.append({
'subnet_cfg': subnet_cfg,
'flops': flops,
'top1_err': top1_err,
'top5_err': top5_err
})

# Phase 2. fine-grained search
try:
best_subnet_info = list(filter(
lambda k: k['flops'] < cfg.BIGNAS.CONSTRAINT_FLOPS,
sorted(benchmarks, key=lambda d: d['top1_err'])))[0]
best_subnet_cfg = best_subnet_info['subnet_cfg']
best_subnet_top1 = best_subnet_info['top1_err']
except IndexError:
logger.info("Cannot find subnets under {} FLOPs".format(cfg.BIGNAS.CONSTRAINT_FLOPS))
exit(1)
for mutate_epoch in range(cfg.BIGNAS.NUM_MUTATE):
new_subnet_cfg = supernet.mutate_and_reset(best_subnet_cfg)
prev_cfgs = [i['subnet_cfg'] for i in benchmarks]
if new_subnet_cfg in prev_cfgs:
continue
subnet = supernet.get_active_subnet().cuda()
# Validate
top1_err, top5_err = validate(subnet, train_loader, valid_loader, test_meter)
flops = supernet.compute_active_subnet_flops()
logger.info("[{}/{}] flops:{} top1_err:{} top5_err:{}".format(
mutate_epoch+1, cfg.BIGNAS.NUM_MUTATE, flops, top1_err, top5_err
))

benchmarks.append({
'subnet_cfg': subnet_cfg,
'flops': flops,
'top1_err': top1_err,
'top5_err': top5_err
})
if flops < cfg.BIGNAS.CONSTRAINT_FLOPS and top1_err < best_subnet_top1:
best_subnet_cfg = new_subnet_cfg
best_subnet_top1 = top1_err
# Final best architecture
logger.info("="*20 + "\nMutate Finished.")
logger.info("Best Architecture:\n{}\n Best top1_err:{}".format(
best_subnet_cfg, best_subnet_top1
))


@torch.no_grad()
def validate(subnet, train_loader, valid_loader, test_meter):
# BN calibration
subnet.eval()
logger.info("Calibrating BN running statistics.")
subnet.reset_running_stats_for_calibration()
for cur_iter, (inputs, _) in enumerate(train_loader):
if cur_iter >= cfg.BIGNAS.POST_BN_CALIBRATION_BATCH_NUM:
break
inputs = inputs.cuda()
subnet(inputs) # forward only

top1_err, top5_err = test_epoch(subnet, valid_loader, test_meter)
return top1_err, top5_err


def test_epoch(subnet, test_loader, test_meter):
subnet.eval()
test_meter.reset(True)
test_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(test_loader):
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
preds = subnet(inputs)
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item()

test_meter.iter_toc()
test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
test_meter.log_iter_stats(0, cur_iter)
test_meter.iter_tic()
top1_err = test_meter.mb_top1_err.get_win_avg()
top5_err = test_meter.mb_top5_err.get_win_avg()
# self.writer.add_scalar('val/top1_error', test_meter.mb_top1_err.get_win_avg(), cur_epoch)
# self.writer.add_scalar('val/top5_error', test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats
test_meter.log_epoch_stats(0)
# test_meter.reset()
return top1_err, top5_err


if __name__ == "__main__":
main()
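
The mutation phase is seeded with the best coarse-phase subnet that satisfies the FLOPs constraint; the list(filter(..., sorted(...)))[0] expression above is equivalent to the more explicit selection below (sketch only, using the same dict layout as benchmarks).

def select_seed(benchmarks, constraint_flops):
    feasible = [b for b in benchmarks if b['flops'] < constraint_flops]
    if not feasible:
        raise SystemExit("no subnet satisfies the FLOPs constraint")
    return min(feasible, key=lambda b: b['top1_err'])  # lowest top-1 error among feasible subnets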

scripts/search/BigNAS/train_supernet.py (+254, -0)

@@ -0,0 +1,254 @@
"""AttentiveNAS supernet training"""

import os
import random

import torch
import torch.nn as nn

import xnas.core.config as config
from xnas.datasets.loader import get_normal_dataloader
import xnas.logger.meter as meter
import xnas.logger.logging as logging
from xnas.core.config import cfg
from xnas.core.builder import *

# DDP
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# AttentiveNAS
from xnas.runner.trainer import Trainer
from xnas.runner.scheduler import adjust_learning_rate_per_batch
from xnas.spaces.OFA.utils import list_mean
from xnas.spaces.BigNAS.utils import init_model

# Load config and check
config.load_configs()
logger = logging.get_logger(__name__)


def main(local_rank, world_size):
setup_env()
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl', rank=local_rank, world_size=world_size)
# Network
net = space_builder().to(local_rank)
init_model(net)
# Loss function
criterion = criterion_builder()
soft_criterion = criterion_builder('kl_soft')
# Data loaders
[train_loader, valid_loader] = get_normal_dataloader()
# Optimizers
net_params = [
# parameters with weight decay
{"params": net.get_parameters(['bn', 'bias'], mode="exclude"), "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
# parameters without weight decay
{"params": net.get_parameters(['bn', 'bias'], mode="include") , "weight_decay": 0},
]
optimizer = optimizer_builder("SGD", net_params)
# Rule: only regularize the biggest model
optimizer_no_wd = torch.optim.SGD(
net.parameters(),
cfg.OPTIM.BASE_LR,
cfg.OPTIM.MOMENTUM,
cfg.OPTIM.DAMPENING,
0, # no weight decay.
cfg.OPTIM.NESTEROV,
)
net = DDP(net, device_ids=[local_rank], find_unused_parameters=True)
# Initialize Recorder
bignas_trainer = BigNASTrainer(
model=net,
criterion=criterion,
soft_criterion=soft_criterion,
optimizer=optimizer,
optim_no_wd=optimizer_no_wd,
lr_scheduler=None,
train_loader=train_loader,
test_loader=valid_loader,
)
# Resume
start_epoch = bignas_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0
# Training
logger.info("Start BigNAS training.")
dist.barrier()
bignas_trainer.start()
for cur_epoch in range(start_epoch, cfg.OPTIM.WARMUP_EPOCH+cfg.OPTIM.MAX_EPOCH):
bignas_trainer.train_epoch(cur_epoch, rank=local_rank)
if local_rank == 0:
if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH:
bignas_trainer.validate(cur_epoch, local_rank)
bignas_trainer.finish()
dist.barrier()
torch.cuda.empty_cache()


class BigNASTrainer(Trainer):
"""Trainer for BigNAS."""
def __init__(self, model, criterion, soft_criterion, optimizer, optim_no_wd, lr_scheduler, train_loader, test_loader):
super().__init__(model, criterion, optimizer, lr_scheduler, train_loader, test_loader)
self.sandwich_sample_num = max(2, cfg.BIGNAS.SANDWICH_NUM) # containing max & min
self.soft_criterion = soft_criterion
self.optim_no_wd = optim_no_wd

def train_epoch(self, cur_epoch, rank=0):
self.model.train()
# lr = self.lr_scheduler.get_last_lr()[0]
cur_step = cur_epoch * len(self.train_loader)
# self.writer.add_scalar('train/lr', lr, cur_step)
self.train_meter.iter_tic()
self.train_loader.sampler.set_epoch(cur_epoch) # DDP
for cur_iter, (inputs, labels) in enumerate(self.train_loader):
inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True)
# Adjust lr per iter
cur_lr = adjust_learning_rate_per_batch(
epoch=cur_epoch,
n_iter=len(self.train_loader),
iter=cur_iter,
warmup=(cur_epoch < cfg.OPTIM.WARMUP_EPOCH),
)
# Rule: constrant ending
cur_lr = max(cur_lr, 0.05 * cfg.OPTIM.BASE_LR)
for param_group in self.optimizer.param_groups:
param_group["lr"] = cur_lr
# self.writer.add_scalar('train/lr', cur_lr, cur_step)
## Sandwich Rule ##
# Step 1. Largest network sampling & regularization
self.optimizer.zero_grad()
self.model.module.sample_max_subnet()
self.model.module.set_dropout_rate(cfg.TRAIN.DROP_PATH_PROB, cfg.BIGNAS.DROP_CONNECT)
preds = self.model(inputs)
loss = self.criterion(preds, labels)
loss.backward()
self.optimizer.step()
with torch.no_grad():
soft_logits = preds.clone().detach()
# Step 2. sample smaller networks
self.optim_no_wd.zero_grad()
self.model.module.set_dropout_rate(0, 0)
for arch_id in range(1, self.sandwich_sample_num):
if arch_id == self.sandwich_sample_num - 1:
self.model.module.sample_min_subnet()
else:
subnet_seed = int("%d%.3d%.3d" % (cur_step, arch_id, 0))
random.seed(subnet_seed)
self.model.module.sample_active_subnet()
preds = self.model(inputs)
if self.soft_criterion is not None:
loss = self.soft_criterion(preds, soft_logits)
else:
loss = self.criterion(preds, labels)
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), cfg.OPTIM.GRAD_CLIP)
self.optim_no_wd.step()
# calculating errors. The source code of AttentiveNAS uses statistics of the smallest network and XNAS follows.
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
self.train_meter.iter_toc()
self.train_meter.update_stats(top1_err, top5_err, loss, cur_lr, inputs.size(0) * cfg.NUM_GPUS)
self.train_meter.log_iter_stats(cur_epoch, cur_iter)
self.train_meter.iter_tic()
# self.writer.add_scalar('train/loss', i_loss, cur_step)
# self.writer.add_scalar('train/top1_error', i_top1err, cur_step)
# self.writer.add_scalar('train/top5_error', i_top5err, cur_step)
cur_step += 1
# Log epoch stats
self.train_meter.log_epoch_stats(cur_epoch)
self.train_meter.reset()
# self.lr_scheduler.step()
# Saving checkpoint
if rank==0 and (cur_epoch + 1) % cfg.SAVE_PERIOD == 0:
self.saving(cur_epoch)
@torch.no_grad()
def test_epoch(self, subnet, cur_epoch, rank=0):
subnet.eval()
self.test_meter.reset(True)
self.test_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(self.test_loader):
inputs, labels = inputs.to(rank), labels.to(rank, non_blocking=True)
preds = subnet(inputs)
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item()
self.test_meter.iter_toc()
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_avg()
top5_err = self.test_meter.mb_top5_err.get_win_avg()
# self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
# self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch)
# self.test_meter.reset()
return top1_err, top5_err


def validate(self, cur_epoch, rank, bn_calibration=True):
subnets_to_be_evaluated = {
'bignas_min_net': {},
'bignas_max_net': {},
}
top1_list, top5_list = [], []
with torch.no_grad():
for net_id in subnets_to_be_evaluated:
if net_id == 'bignas_min_net':
self.model.module.sample_min_subnet()
elif net_id == 'bignas_max_net':
self.model.module.sample_max_subnet()
elif net_id.startswith('bignas_random_net'):
self.model.module.sample_active_subnet()
else:
self.model.module.set_active_subnet(
subnets_to_be_evaluated[net_id]['resolution'],
subnets_to_be_evaluated[net_id]['width'],
subnets_to_be_evaluated[net_id]['depth'],
subnets_to_be_evaluated[net_id]['kernel_size'],
subnets_to_be_evaluated[net_id]['expand_ratio'],
)

subnet = self.model.module.get_active_subnet()
subnet.to(rank)
logger.info("evaluating subnet {}".format(net_id))
if bn_calibration:
subnet.eval()
logger.info("Calibrating BN running statistics.")
subnet.reset_running_stats_for_calibration()
for cur_iter, (inputs, _) in enumerate(self.train_loader):
if cur_iter >= cfg.BIGNAS.POST_BN_CALIBRATION_BATCH_NUM:
break
inputs = inputs.to(rank)
subnet(inputs) # forward only
top1_err, top5_err = self.test_epoch(subnet, cur_epoch, rank)
top1_list.append(top1_err), top5_list.append(top5_err)
logger.info("Average@all_subnets top1_err:{} top5_err:{}".format(list_mean(top1_list), list_mean(top5_list)))
if self.best_err > list_mean(top1_list):
self.best_err = list_mean(top1_list)
self.saving(cur_epoch, best=True)


if __name__ == '__main__':
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '23333'
if torch.cuda.is_available():
cfg.NUM_GPUS = torch.cuda.device_count()
mp.spawn(main, nprocs=cfg.NUM_GPUS, args=(cfg.NUM_GPUS,), join=True)
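
BigNASTrainer departs from the AttentiveNAS trainer in two of the BigNAS training rules: weight decay (and dropout/drop-connect) is applied only when the largest subnet is sampled, which is why the smaller subnets are stepped with the separate optimizer_no_wd, and the learning rate never falls below 5% of BASE_LR. What that per-batch schedule looks like is sketched below under the assumption that adjust_learning_rate_per_batch does linear warmup followed by cosine decay; the authoritative version is xnas/runner/scheduler.py.

import math

def lr_at(epoch, it, iters_per_epoch, base_lr=0.1, warmup_epochs=5, max_epochs=360):
    t = epoch + it / iters_per_epoch                 # fractional epoch
    if t < warmup_epochs:
        lr = base_lr * t / warmup_epochs             # linear warmup (assumed)
    else:
        progress = (t - warmup_epochs) / max_epochs  # cosine decay (assumed)
        lr = 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
    return max(lr, 0.05 * base_lr)                   # "constant ending" floor applied in the loop above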

scripts/search/OFA/eval_supernet.py (+5, -5)

@@ -96,7 +96,7 @@ def main():
     # load_last_stage_ckpt(cfg.OFA.TASK, cfg.OFA.PHASE)
     # ofa_trainer.resume() # only load the state_dict of model
 
-    # cfg.SEARCH.WEIGHTS = '/home/xfey/XNAS/exp/search/OFA_trail_25/kernel_1/checkpoints/model_epoch_0110.pyth'
+    # cfg.SEARCH.WEIGHTS = '/home/xfey/XNAS/exp/search/OFA_trial_25/kernel_1/checkpoints/model_epoch_0110.pyth'
     cfg.SEARCH.WEIGHTS = '/home/xfey/XNAS/tests/weights/ofa_D4_E6_K357'
     ofa_trainer.resume()
 
@@ -137,10 +137,10 @@ class OFATrainer(KDTrainer):
             self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
             self.test_meter.log_iter_stats(cur_epoch, cur_iter)
             self.test_meter.iter_tic()
-        top1_err = self.test_meter.mb_top1_err.get_win_median()
-        top5_err = self.test_meter.mb_top5_err.get_win_median()
-        # self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
-        # self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
+        top1_err = self.test_meter.mb_top1_err.get_win_avg()
+        top5_err = self.test_meter.mb_top5_err.get_win_avg()
+        # self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
+        # self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
         # Log epoch stats
         self.test_meter.log_epoch_stats(cur_epoch)
         # self.test_meter.reset()


scripts/search/OFA/train_supernet.py (+8, -8)

@@ -9,6 +9,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 import xnas.core.config as config
+from xnas.datasets.loader import get_normal_dataloader
 import xnas.logger.meter as meter
 import xnas.logger.logging as logging
 from xnas.core.config import cfg
@@ -44,7 +45,7 @@ def main(local_rank, world_size):
     # Loss function
     criterion = criterion_builder()
     # Data loaders
-    [train_loader, valid_loader] = construct_loader()
+    [train_loader, valid_loader] = get_normal_dataloader()
     # Optimizers
     net_params = [
         # parameters with weight decay
@@ -241,10 +242,10 @@ class OFATrainer(KDTrainer):
             self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
             self.test_meter.log_iter_stats(cur_epoch, cur_iter)
             self.test_meter.iter_tic()
-        top1_err = self.test_meter.mb_top1_err.get_win_median()
-        top5_err = self.test_meter.mb_top5_err.get_win_median()
-        # self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
-        # self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
+        top1_err = self.test_meter.mb_top1_err.get_win_avg()
+        top5_err = self.test_meter.mb_top5_err.get_win_avg()
+        # self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
+        # self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
         # Log epoch stats
         self.test_meter.log_epoch_stats(cur_epoch)
         # self.test_meter.reset()
@@ -320,8 +321,8 @@ class OFATrainer(KDTrainer):
         logger.info("Average@all_subnets top1_err:{} top5_err:{}".format(list_mean(top1errs), list_mean(top5errs)))
         # Saving best model
-        if self.best_err > top1_err:
-            self.best_err = top1_err
+        if self.best_err > list_mean(top1errs):
+            self.best_err = list_mean(top1errs)
             self.saving(cur_epoch, best=True)
 
@@ -331,6 +332,5 @@ if __name__ == '__main__':
     if torch.cuda.is_available():
         cfg.NUM_GPUS = torch.cuda.device_count()
-        print(cfg.NUM_GPUS)
     mp.spawn(main, nprocs=cfg.NUM_GPUS, args=(cfg.NUM_GPUS,), join=True)

scripts/search/RMINAS.py (+35, -25)

@@ -4,7 +4,7 @@ import time
 import numpy as np
 
 import torch
+from fvcore.nn import FlopCountAnalysis
 import xnas.core.config as config
 import xnas.logger.logging as logging
 from xnas.core.config import cfg
@@ -38,6 +38,8 @@ def rminas_hp_builder():
         RF_space = 'nasbenchmacro'
         from xnas.evaluations.NASBenchMacro.evaluate import evaluate, data
         api = data
+    elif cfg.SPACE.NAME == 'proxyless':
+        RF_space = 'proxyless'
         # for example : arch = '00000000'
         # arch = ''
         # evaluate(arch)
@@ -48,7 +50,7 @@ def main():
     rminas_hp_builder()
-    assert cfg.SPACE.NAME in ['infer_nb201', 'infer_darts',"nasbenchmacro"]
+    assert cfg.SPACE.NAME in ['infer_nb201', 'infer_darts',"nasbenchmacro", "proxyless"]
     assert cfg.LOADER.DATASET in ['cifar10', 'cifar100', 'imagenet', 'imagenet16_120'], 'dataset error'
 
     if cfg.LOADER.DATASET == 'cifar10':
@@ -64,7 +66,7 @@ def main():
         network.load_state_dict(torch.load('xnas/algorithms/RMINAS/teacher_model/resnet101_cifar100/resnet101.pth'))
 
     elif cfg.LOADER.DATASET == 'imagenet':
-        assert cfg.SPACE.NAME == 'infer_darts'
+        assert cfg.SPACE.NAME in ('infer_darts', 'proxyless')
         logger.warning('Our method does not directly search in ImageNet.')
         logger.warning('Only partial tests have been conducted, please use with caution.')
         import xnas.algorithms.RMINAS.teacher_model.fbresnet_imagenet.fbresnet as fbresnet
@@ -93,8 +95,8 @@ def main():
     ce_loss = torch.nn.CrossEntropyLoss(reduction='none').cuda()
     more_logits = network(more_data_X)
     _, indices = torch.topk(-ce_loss(more_logits, more_data_y).cpu().detach(), cfg.LOADER.BATCH_SIZE)
-    data_y = torch.Tensor([more_data_y[i] for i in indices]).long().cuda()
-    data_X = torch.Tensor([more_data_X[i].cpu().numpy() for i in indices]).cuda()
+    data_y = more_data_y.detach()
+    data_X = more_data_X.detach()
     with torch.no_grad():
         feature_res = network.feature_extractor(data_X)
@@ -107,6 +109,7 @@ def main():
     loss_fun_log = torch.nn.CrossEntropyLoss().cuda()
     def train_arch(modelinfo):
+        flops = None
         if cfg.SPACE.NAME == 'infer_nb201':
             # get arch
             arch_config = {
@@ -122,6 +125,12 @@ def main():
         elif cfg.SPACE.NAME == 'nasbenchmacro':
            model = space_builder().cuda()
            optimizer = optimizer_builder("SGD", model.parameters())
+        elif cfg.SPACE.NAME == 'proxyless':
+            model = space_builder(stage_width_list=[16, 24, 40, 80, 96, 192, 320],depth_param=modelinfo[:6],ks=modelinfo[6:27][modelinfo[6:27]>0],expand_ratio=modelinfo[27:][modelinfo[27:]>0],dropout_rate=0).cuda()
+            optimizer = optimizer_builder("SGD", model.parameters())
+            with torch.no_grad():
+                tensor = (torch.rand(1, 3, 224, 224).cuda(),)
+                flops = FlopCountAnalysis(model, tensor).total()
         # lr_scheduler = lr_scheduler_builder(optimizer)
 
         # nbm_trainer = OneShotTrainer(
@@ -150,12 +159,14 @@ def main():
             optimizer.step()
             epoch_losses.append(loss.detach().cpu().item())
             if cur_epoch == cfg.OPTIM.MAX_EPOCH:
-                return loss.cpu().detach().numpy(), epoch_losses
+                return loss.cpu().detach().numpy(), {'epoch_losses':epoch_losses, 'flops':flops}
 
     trained_arch_darts, trained_loss = [], []
     def train_procedure(sample):
         if cfg.SPACE.NAME == 'infer_nb201':
-            mixed_loss = train_arch(sample)[0]
+            mixed_loss, epoch_losses = train_arch(sample)[0]
             mixed_loss = np.inf if np.isnan(mixed_loss) else mixed_loss
             trained_loss.append(mixed_loss)
             arch_arr = sampling.nb201genostr2array(api.arch(sample))
@@ -164,17 +175,25 @@ def main():
         elif cfg.SPACE.NAME == 'infer_darts':
             sample_geno = geno_from_alpha(sampling.darts_sug2alpha(sample)) # type=Genotype
             trained_arch_darts.append(str(sample_geno))
-            mixed_loss = train_arch(sample_geno)[0]
+            mixed_loss, epoch_losses = train_arch(sample_geno)[0]
             mixed_loss = np.inf if np.isnan(mixed_loss) else mixed_loss
             trained_loss.append(mixed_loss)
             RFS.trained_arch.append({'arch':sample, 'loss':mixed_loss})
         elif cfg.SPACE.NAME == 'nasbenchmacro':
             sample_geno = ''.join(sample.astype('str')) # type=Genotype
             trained_arch_darts.append((sample_geno))
-            mixed_loss, epoch_losses = train_arch(sample)
+            mixed_loss, info = train_arch(sample)
             mixed_loss = np.inf if np.isnan(mixed_loss) else mixed_loss
             trained_loss.append(mixed_loss)
-            RFS.trained_arch.append({'arch':sample, 'loss':mixed_loss,'gt':api[sample_geno]['mean_acc'],'losses':epoch_losses})
+            RFS.trained_arch.append({'arch':sample, 'loss':mixed_loss,'gt':api[sample_geno]['mean_acc'],'losses':info["epoch_losses"]})
+        elif cfg.SPACE.NAME == 'proxyless':
+            sample_geno = ''.join(sample.astype('str')) # type=Genotype
+            trained_arch_darts.append((sample_geno))
+            mixed_loss, info = train_arch(sample)
+            mixed_loss = np.inf if np.isnan(mixed_loss) else mixed_loss
+            trained_loss.append(mixed_loss)
+            RFS.trained_arch.append({'arch':sample, 'loss':mixed_loss,'gt':info["flops"],'losses':info["epoch_losses"]})
 
         logger.info("sample: {}, loss:{}".format(sample, mixed_loss))
@@ -185,21 +204,6 @@ def main():
     for sample in warmup_samples:
         train_procedure(sample)
     RFS.Warmup()
-    logger.info('warmup time cost: {}'.format(str(time.time() - start_time)))
-    # with open('./rmi_nbm.pkl','wb') as f:
-    #     pickle.dump(RFS.trained_arch,f)
-    # gt = np.array([_['gt'] for _ in RFS.trained_arch])
-    # losses = np.array([_['losses'] for _ in RFS.trained_arch])
-    # from scipy.stats import kendalltau
-    # kdts = []
-    # for epoch in range(losses.shape[-1]):
-    #     kdts.append(kendalltau(gt, -losses[:, epoch]).correlation)
-    # import matplotlib.pyplot as plt
-    # plt.plot(kdts)
-    # plt.xlabel('epoch')
-    # plt.ylabel('kdt')
-    # plt.savefig('rmi_nbm.png')
-    # sys.exit()
     # ====== RF Sampling ======
     sampling_time = time.time()
     sampling_cnt= 0
@@ -231,6 +235,12 @@ def main():
         # op_geno = reformat_DARTS(geno_from_alpha(op_alpha))
         logger.info('Searched architecture@top50:\n{}'.format(str(op_sample)))
         print(api[op_sample]['mean_acc'])
+    elif cfg.SPACE.NAME == 'proxyless':
+        op_sample = RFS.optimal_arch(method='sum', top=100)
+        # op_alpha = torch.from_numpy(np.r_[op_sample, op_sample])
+        # op_geno = reformat_DARTS(geno_from_alpha(op_alpha))
+        logger.info('Searched architecture@top100:\n{}'.format(str(op_sample)))
+        print(api[op_sample]['mean_acc'])
 
 if __name__ == '__main__':
     main()
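
The new 'proxyless' branch encodes an architecture as a flat vector of 48 integers: six per-stage depths, then 21 per-block kernel sizes, then 21 per-block expand ratios, with zeros marking blocks disabled by the chosen depth. train_arch() strips the zeros before calling space_builder() and measures FLOPs with fvcore's FlopCountAnalysis. A small hand-built example of the layout and its decoding (values and helper code are illustrative; only the layout comes from the diff):

import numpy as np

rng = np.random.default_rng(0)
depth = np.array([2, 3, 4, 2, 1, 1])       # blocks kept per stage (5 searchable stages + fixed last block)
ks = np.zeros(21, dtype=int)
er = np.zeros(21, dtype=int)
for stage, d in enumerate(depth[:5]):      # each searchable stage owns 4 candidate blocks
    lo = 4 * stage
    ks[lo:lo + d] = rng.choice([3, 5, 7], size=d)
    er[lo:lo + d] = rng.choice([3, 6], size=d)
ks[20], er[20] = 3, 6                      # the single block of the final stage
sample = np.concatenate([depth, ks, er])   # 48-dim vector handled by RF_suggest

# decoding, as train_arch() does before calling space_builder():
print(sample[:6])                          # depth_param
print(sample[6:27][sample[6:27] > 0])      # kernel sizes of active blocks only
print(sample[27:][sample[27:] > 0])        # expand ratios of active blocks only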

scripts/train/DARTS.py (+3, -3)

@@ -97,9 +97,9 @@ class Darts_Retrainer(Trainer):
             self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
             self.test_meter.log_iter_stats(cur_epoch, cur_iter)
             self.test_meter.iter_tic()
-        top1_err = self.test_meter.mb_top1_err.get_win_median()
-        self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
-        self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
+        top1_err = self.test_meter.mb_top1_err.get_win_avg()
+        self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
+        self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
         # Log epoch stats
         self.test_meter.log_epoch_stats(cur_epoch)
         self.test_meter.reset()


tests/ofa_matrices_test.py (+1, -1)

@@ -2,7 +2,7 @@ import os
 import torch
 
 def test_local():
-    root = '/home/xfey/XNAS/exp/search/OFA_trail_25/kernel_1/checkpoints/'
+    root = '/home/xfey/XNAS/exp/search/OFA_trial_25/kernel_1/checkpoints/'
     filename_prefix = 'model_epoch_'
     filename_postfix = '.pyth'
 

xnas/algorithms/AttentiveNAS/sampler.py (+117, -0)

@@ -0,0 +1,117 @@
import random

def count_helper(v, flops, m):
    if flops not in m:
        m[flops] = {}
    if v not in m[flops]:
        m[flops][v] = 0
    m[flops][v] += 1


def round_flops(flops, step):
    return int(round(flops / step) * step)


def convert_count_to_prob(m):
    if isinstance(m[list(m.keys())[0]], dict):
        for k in m:
            convert_count_to_prob(m[k])
    else:
        t = sum(m.values())
        for k in m:
            m[k] = 1.0 * m[k] / t


def sample_helper(flops, m):
    keys = list(m[flops].keys())
    probs = list(m[flops].values())
    return random.choices(keys, weights=probs)[0]


def build_trasition_prob_matrix(file_handler, step):
    # initlizie
    prob_map = {}
    prob_map['discretize_step'] = step
    for k in ['flops', 'resolution', 'width', 'depth', 'kernel_size', 'expand_ratio']:
        prob_map[k] = {}

    cc = 0
    for line in file_handler:
        vals = eval(line.strip())

        # discretize
        flops = round_flops(vals['flops'], step)
        prob_map['flops'][flops] = prob_map['flops'].get(flops, 0) + 1

        # resolution
        r = vals['resolution']
        count_helper(r, flops, prob_map['resolution'])

        for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
            for idx, v in enumerate(vals[k]):
                if idx not in prob_map[k]:
                    prob_map[k][idx] = {}
                count_helper(v, flops, prob_map[k][idx])

        cc += 1

    # convert count to probability
    for k in ['flops', 'resolution', 'width', 'depth', 'kernel_size', 'expand_ratio']:
        convert_count_to_prob(prob_map[k])
    prob_map['n_observations'] = cc
    return prob_map


class ArchSampler():
    def __init__(self, arch_to_flops_map_file_path, discretize_step, model, acc_predictor=None):
        super(ArchSampler, self).__init__()

        with open(arch_to_flops_map_file_path, 'r') as fp:
            self.prob_map = build_trasition_prob_matrix(fp, discretize_step)

        self.discretize_step = discretize_step
        self.model = model

        self.acc_predictor = acc_predictor

        self.min_flops = min(list(self.prob_map['flops'].keys()))
        self.max_flops = max(list(self.prob_map['flops'].keys()))

        self.curr_sample_pool = None #TODO; architecture samples could be generated in an asynchronous way

    def sample_one_target_flops(self, flops_uniform=False):
        f_vals = list(self.prob_map['flops'].keys())
        f_probs = list(self.prob_map['flops'].values())

        if flops_uniform:
            return random.choice(f_vals)
        else:
            return random.choices(f_vals, weights=f_probs)[0]

    def sample_archs_according_to_flops(self, target_flops, n_samples=1, max_trials=100, return_flops=True, return_trials=False):
        archs = []
        #for _ in range(n_samples):
        while len(archs) < n_samples:
            for _trial in range(max_trials+1):
                arch = {}
                arch['resolution'] = sample_helper(target_flops, self.prob_map['resolution'])
                for k in ['width', 'kernel_size', 'depth', 'expand_ratio']:
                    arch[k] = []
                    for idx in sorted(list(self.prob_map[k].keys())):
                        arch[k].append(sample_helper(target_flops, self.prob_map[k][idx]))
                if self.model:
                    self.model.set_active_subnet(**arch)
                    flops = self.model.compute_active_subnet_flops()
                    if return_flops:
                        arch['flops'] = flops
                    if round_flops(flops, self.discretize_step) == target_flops:
                        break
                else:
                    raise NotImplementedError
            #accepte the sample anyway
            archs.append(arch)
        return archs
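
ArchSampler is consumed by scripts/search/AttentiveNAS/train_supernet.py with the MAP_PATH, DISCRETIZE_STEP and NUM_TRIALS values from train.yaml. Each line of the .map file is expected to be a Python dict literal with 'flops', 'resolution', 'width', 'depth', 'kernel_size' and 'expand_ratio' keys (that is what build_trasition_prob_matrix() eval()s). A usage sketch (wrapper and names illustrative):

def pick_candidates(supernet, n=3):
    # `supernet` stands for the 'attentivenas' network from space_builder(); placeholder name.
    sampler = ArchSampler(
        arch_to_flops_map_file_path='xnas/algorithms/AttentiveNAS/flops_archs_off_table.map',
        discretize_step=25,
        model=supernet,
    )
    target = sampler.sample_one_target_flops()  # FLOPs bucket drawn from the table's distribution
    # candidates share (roughly) the target FLOPs; the trainer then scores them on the current
    # batch and keeps the best ('bestup') or worst ('worstup') one.
    return sampler.sample_archs_according_to_flops(target, n_samples=n)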


xnas/algorithms/RMINAS/sampler/RF_sampling.py (+47, -1)

@@ -35,7 +35,8 @@ class RF_suggest():
             self.max_space = int(3**8)
             self.num_estimator = 30
             self.spaces = list(api.keys())
+        elif self.space == 'proxyless':
+            self.num_estimator = 100
         self.model = RandomForestClassifier(n_estimators=self.num_estimator,random_state=seed)
 
     def _update_lossthres(self):
@@ -74,6 +75,8 @@ class RF_suggest():
             return [self._single_sample() for _ in range(num_warmup)]
         elif self.space == 'nasbenchmacro':
             return [self._single_sample() for _ in range(num_warmup)]
+        elif self.space == 'proxyless':
+            return [self._single_sample() for _ in range(num_warmup)]
 
     def _single_sample(self, unique=True):
@@ -125,6 +128,28 @@ class RF_suggest():
             else:
                 numeric_choice = np.random.randint(3,size=8)
             return numeric_choice
+        elif self.space == 'proxyless':
+            def gen_sample():
+                depth = np.array(np.random.randint(1, 4+1, size=5).tolist() + [1])
+                anchors = depth+[0,4,8,12,16,20]
+                ks = np.random.choice([3,5,7], size=21)
+                expand_ratios = np.random.choice([3,6], size=21)
+                ed = 4
+                for anchor in anchors:
+                    ks[anchor:ed] = 0
+                    expand_ratios[anchor:ed] = 0
+                    ed += 4
+                sample = np.concatenate([depth, ks, expand_ratios])
+                return sample
+            if unique:
+                while True:
+                    sample = gen_sample()
+                    if sample.tobytes() not in self.sampled_history:
+                        self.sampled_history.append(sample.tobytes())
+                        return sample
+            else:
+                sample = gen_sample()
+                return sample
 
     def Warmup(self):
@@ -177,6 +202,18 @@ class RF_suggest():
             for i in _sample_indexes:
                 if self.spaces[i] not in chace_table:
                     _sample_archs.append(np.array(list(self.spaces[i])).astype(int))
+        elif self.space == 'proxyless':
+            _sample_batch = np.array([self._single_sample(unique=False).ravel() for _ in range(self.batch)])
+            _tmp_trained_arch = [(i['arch'].tobytes()) for i in self.trained_arch]
+            _sample_archs = []
+            for i in _sample_batch:
+                if (i).tobytes() not in _tmp_trained_arch:
+                    _sample_archs.append(i)
+            # print("sample {} archs/batch, cost time: {}".format(len(_sample_archs), time.time()-start_time))
+            best_id = np.argmax(self.model.predict_proba(_sample_archs)[:,1])
+            best_arch = _sample_archs[best_id]
+            return best_arch
         # _sample_batch = np.array([self._single_sample(unique=True).ravel() for _ in range(self.batch)])
         # _tmp_trained_arch = [str(i['arch'].ravel()) for i in self.trained_arch]
         # _sample_archs = []
@@ -311,3 +348,12 @@ class RF_suggest():
             op_arr = np.zeros((_tmp_np.size, 3))
             op_arr[np.arange(_tmp_np.size),_tmp_np] = 1
             return op_arr.argmax(-1)
+        elif self.space == 'proxyless':
+            assert method == 'sum', 'only sum is supported in mb.'
+            depth = estimate_archs[:, :6]
+            best_depth = np.eye(4)[depth].argmax(-1)+1
+            ks = estimate_archs[:, 6:27]//2 # {3, 5, 7}
+            best_ks = np.eye(3)[ks].argmax(-1) * 2 + 3
+            er = estimate_archs[:, 27:]//3 # {3, 6}
+            best_er = np.eye(2)[er].agrmax(-1) * 3 + 3
+            return np.concatenate([best_depth, best_ks, best_er])
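
The proposal step added for 'proxyless' mirrors the other spaces: draw a batch of fresh encodings, let the fitted random forest score them, and train the one with the highest predicted probability of being a "good" (low-loss) architecture. A self-contained toy version of that step, with dummy data in place of the real encodings and labels:

import numpy as np
from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier(n_estimators=100, random_state=0)
X_seen = np.random.randint(0, 8, size=(120, 48))        # already-evaluated encodings (dummy)
y_seen = (np.random.rand(120) < 0.3).astype(int)        # 1 = loss below threshold ("good"), dummy labels
rf.fit(X_seen, y_seen)

batch = np.random.randint(0, 8, size=(64, 48))          # freshly sampled candidate encodings
best = batch[np.argmax(rf.predict_proba(batch)[:, 1])]  # highest predicted probability of "good"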

xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet.py (+1, -1)

@@ -7,7 +7,7 @@ import math
 import torch.utils.model_zoo as model_zoo
 import torch
 
-WEIGHT_PATH = 'teacher_model/fbresnet_imagenet/fbresnet152.pth'
+WEIGHT_PATH = 'xnas/algorithms/RMINAS/teacher_model/fbresnet_imagenet/fbresnet152.pth'
 
 __all__ = ['FBResNet',
            #'fbresnet18', 'fbresnet34', 'fbresnet50', 'fbresnet101',


xnas/algorithms/RMINAS/utils/random_data.py (+8, -10)

@@ -1,5 +1,6 @@
 import torch
 import random
+import numpy as np
 from xnas.datasets.loader import get_normal_dataloader
 from xnas.datasets.imagenet import ImageFolder
 
@@ -8,18 +9,15 @@ def get_random_data(batchsize, name):
     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     if name == 'imagenet':
         train_loader, _ = ImageFolder(
-            "./data/imagenet/ILSVRC2012_img_train/",
-            [0.5, 0.5],
-            batchsize*16,
+            datapath="./data/imagenet/ILSVRC2012_img_train/",
+            batch_size=batchsize*16,
+            split=[0.5, 0.5],
         ).generate_data_loader()
     else:
         train_loader, _ = get_normal_dataloader(name, batchsize*16)
-    target_i = random.randint(0, len(train_loader)-1)
-    more_data_X, more_data_y = None, None
-    for i, (more_data_X, more_data_y) in enumerate(train_loader):
-        if i == target_i:
-            break
-    more_data_X = more_data_X.to(device)
-    more_data_y = more_data_y.to(device)
+    random_idxs = np.random.randint(0, len(train_loader.dataset), size=train_loader.batch_size)
+    (more_data_X, more_data_y) = zip(*[train_loader.dataset[idx] for idx in random_idxs])
+    more_data_X = torch.stack(more_data_X, dim=0).to(device)
+    more_data_y = torch.Tensor(more_data_y).long().to(device)
     return more_data_X, more_data_y

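The rewritten get_random_data draws a batch by indexing the underlying dataset directly instead of iterating the loader. A standalone sketch of that pattern, assuming a torch Dataset that returns (image_tensor, label):

import torch
import numpy as np

# Sketch of the sampling pattern used above: pick random indices from the dataset
# and stack them into one batch, skipping DataLoader iteration entirely.
def sample_random_batch(dataset, batch_size, device="cpu"):
    idxs = np.random.randint(0, len(dataset), size=batch_size)
    xs, ys = zip(*[dataset[i] for i in idxs])
    x = torch.stack(xs, dim=0).to(device)
    y = torch.tensor(ys, dtype=torch.long, device=device)
    return x, y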
+ 12
- 1
xnas/core/builder.py View File

@@ -25,6 +25,7 @@ from xnas.runner.optimizer import optimizer_builder
from xnas.runner.criterion import criterion_builder from xnas.runner.criterion import criterion_builder
from xnas.runner.scheduler import lr_scheduler_builder from xnas.runner.scheduler import lr_scheduler_builder



__all__ = [ __all__ = [
'construct_loader', 'construct_loader',
'optimizer_builder', 'optimizer_builder',
@@ -48,9 +49,12 @@ from xnas.spaces.DrNAS.darts_cnn import _DrNAS_DARTS_CNN
from xnas.spaces.DrNAS.nb201_cnn import _DrNAS_nb201_CNN, _GDAS_nb201_CNN from xnas.spaces.DrNAS.nb201_cnn import _DrNAS_nb201_CNN, _GDAS_nb201_CNN
from xnas.spaces.SPOS.cnn import _SPOS_CNN, _infer_SPOS_CNN from xnas.spaces.SPOS.cnn import _SPOS_CNN, _infer_SPOS_CNN
from xnas.spaces.DropNAS.cnn import _DropNASCNN from xnas.spaces.DropNAS.cnn import _DropNASCNN
from xnas.spaces.ProxylessNAS.cnn import _MobileNetV2
from xnas.spaces.OFA.MobileNetV3.ofa_cnn import _OFAMobileNetV3 from xnas.spaces.OFA.MobileNetV3.ofa_cnn import _OFAMobileNetV3
from xnas.spaces.OFA.ProxylessNet.ofa_cnn import _OFAProxylessNASNet from xnas.spaces.OFA.ProxylessNet.ofa_cnn import _OFAProxylessNASNet
from xnas.spaces.OFA.ResNets.ofa_cnn import _OFAResNet from xnas.spaces.OFA.ResNets.ofa_cnn import _OFAResNet
from xnas.spaces.BigNAS.cnn import _BigNAS_CNN, _infer_BigNAS_CNN
from xnas.spaces.AttentiveNAS.cnn import _AttentiveNAS_CNN, _infer_AttentiveNAS_CNN
from xnas.spaces.NASBenchMacro.cnn import _NBMacro_child_train, _NBMacro_sup_train from xnas.spaces.NASBenchMacro.cnn import _NBMacro_child_train, _NBMacro_sup_train


SUPPORTED_SPACES = { SUPPORTED_SPACES = {
@@ -63,15 +67,22 @@ SUPPORTED_SPACES = {
"gdas_nb201": _GDAS_nb201_CNN, "gdas_nb201": _GDAS_nb201_CNN,
"dropnas": _DropNASCNN, "dropnas": _DropNASCNN,
"spos": _SPOS_CNN, "spos": _SPOS_CNN,
"spos_nb201": _SPOS_nb201_CNN,
"nasbenchmacro": _NBMacro_sup_train, "nasbenchmacro": _NBMacro_sup_train,
"ofa_mbv3": _OFAMobileNetV3, "ofa_mbv3": _OFAMobileNetV3,
"ofa_proxyless": _OFAProxylessNASNet, "ofa_proxyless": _OFAProxylessNASNet,
"ofa_resnet": _OFAResNet, "ofa_resnet": _OFAResNet,
# models for inference
"attentivenas": _AttentiveNAS_CNN,
"bignas": _BigNAS_CNN,
# ===== models for inference =====
"infer_darts": _infer_DartsCNN, "infer_darts": _infer_DartsCNN,
"infer_nb201": _infer_NASBench201, "infer_nb201": _infer_NASBench201,
"infer_spos": _infer_SPOS_CNN, "infer_spos": _infer_SPOS_CNN,
"infer_attentivenas": _infer_AttentiveNAS_CNN,
# "infer_bignas": _infer_BigNAS_CNN,
"spos_nb201": _SPOS_nb201_CNN, "spos_nb201": _SPOS_nb201_CNN,
# "proxyless": _SuperProxylessNASNets,
"proxyless": _MobileNetV2,
} }






+ 6
- 1
xnas/core/config.py View File

@@ -37,6 +37,9 @@ _C.LOADER.PIN_MEMORY = True
# _C.LOADER.BATCH_SIZE = [256, 128] # _C.LOADER.BATCH_SIZE = [256, 128]
_C.LOADER.BATCH_SIZE = 256 _C.LOADER.BATCH_SIZE = 256


# augmentation type, used by ImageNet only
# chosen from ['default', 'auto_augment_tf']
_C.LOADER.TRANSFORM = "default"




# ------------------------------------------------------------------------------------ # # ------------------------------------------------------------------------------------ #
@@ -150,7 +153,9 @@ _C.TEST = CfgNode(new_allowed=True)


_C.TEST.IM_SIZE = 224 _C.TEST.IM_SIZE = 224


_C.TEST.BATCH_SIZE = 128
# use a specific batch size for testing
# falls back to LOADER.BATCH_SIZE if this value stays -1
_C.TEST.BATCH_SIZE = -1








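The two new keys behave like any other yacs option; a hedged sketch of overriding them programmatically (values illustrative, the same keys can be set in a yaml config):

from xnas.core.config import cfg

# Illustrative override of LOADER.TRANSFORM and TEST.BATCH_SIZE.
cfg.merge_from_list(["LOADER.TRANSFORM", "auto_augment_tf", "TEST.BATCH_SIZE", 196])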
+ 402
- 0
xnas/datasets/auto_augment_tf.py View File

@@ -0,0 +1,402 @@
""" Auto Augment
Implementation adapted from timm: https://github.com/rwightman/pytorch-image-models
"""

import random
import math
from PIL import Image, ImageOps, ImageEnhance
import PIL


_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])

_FILL = (128, 128, 128)

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.

_HPARAMS_DEFAULT = dict(
translate_const=250,
img_mean=_FILL,
)

_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC)


def _interpolation(kwargs):
interpolation = kwargs.pop('resample', Image.NEAREST)
if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation)
else:
return interpolation


def _check_args_tf(kwargs):
if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
kwargs.pop('fillcolor')
kwargs['resample'] = _interpolation(kwargs)


def shear_x(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)


def shear_y(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)


def translate_x_rel(img, pct, **kwargs):
pixels = pct * img.size[0]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_rel(img, pct, **kwargs):
pixels = pct * img.size[1]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def translate_x_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def rotate(img, degrees, **kwargs):
_check_args_tf(kwargs)
if _PIL_VER >= (5, 2):
return img.rotate(degrees, **kwargs)
elif _PIL_VER >= (5, 0):
w, h = img.size
post_trans = (0, 0)
rotn_center = (w / 2.0, h / 2.0)
angle = -math.radians(degrees)
matrix = [
round(math.cos(angle), 15),
round(math.sin(angle), 15),
0.0,
round(-math.sin(angle), 15),
round(math.cos(angle), 15),
0.0,
]

def transform(x, y, matrix):
(a, b, c, d, e, f) = matrix
return a * x + b * y + c, d * x + e * y + f

matrix[2], matrix[5] = transform(
-rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
)
matrix[2] += rotn_center[0]
matrix[5] += rotn_center[1]
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
else:
return img.rotate(degrees, resample=kwargs['resample'])


def auto_contrast(img, **__):
return ImageOps.autocontrast(img)


def invert(img, **__):
return ImageOps.invert(img)


def equalize(img, **__):
return ImageOps.equalize(img)


def solarize(img, thresh, **__):
return ImageOps.solarize(img, thresh)


def solarize_add(img, add, thresh=128, **__):
lut = []
for i in range(256):
if i < thresh:
lut.append(min(255, i + add))
else:
lut.append(i)
if img.mode in ("L", "RGB"):
if img.mode == "RGB" and len(lut) == 256:
lut = lut + lut + lut
return img.point(lut)
else:
return img


def posterize(img, bits_to_keep, **__):
if bits_to_keep >= 8:
return img
bits_to_keep = max(1, bits_to_keep) # prevent all 0 images
return ImageOps.posterize(img, bits_to_keep)


def contrast(img, factor, **__):
return ImageEnhance.Contrast(img).enhance(factor)


def color(img, factor, **__):
return ImageEnhance.Color(img).enhance(factor)


def brightness(img, factor, **__):
return ImageEnhance.Brightness(img).enhance(factor)


def sharpness(img, factor, **__):
return ImageEnhance.Sharpness(img).enhance(factor)


def _randomly_negate(v):
"""With 50% prob, negate the value"""
return -v if random.random() > 0.5 else v


def _rotate_level_to_arg(level):
# range [-30, 30]
level = (level / _MAX_LEVEL) * 30.
level = _randomly_negate(level)
return (level,)


def _enhance_level_to_arg(level):
# range [0.1, 1.9]
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)


def _shear_level_to_arg(level):
# range [-0.3, 0.3]
level = (level / _MAX_LEVEL) * 0.3
level = _randomly_negate(level)
return (level,)


def _translate_abs_level_to_arg(level, translate_const):
level = (level / _MAX_LEVEL) * float(translate_const)
level = _randomly_negate(level)
return (level,)

def _translate_abs_level_to_arg2(level):
level = (level / _MAX_LEVEL) * float(_HPARAMS_DEFAULT['translate_const'])
level = _randomly_negate(level)
return (level,)

def _translate_rel_level_to_arg(level):
# range [-0.45, 0.45]
level = (level / _MAX_LEVEL) * 0.45
level = _randomly_negate(level)
return (level,)


# def level_to_arg(hparams):
# return {
# 'AutoContrast': lambda level: (),
# 'Equalize': lambda level: (),
# 'Invert': lambda level: (),
# 'Rotate': _rotate_level_to_arg,
# # FIXME these are both different from original impl as I believe there is a bug,
# # not sure what is the correct alternative, hence 2 options that look better
# 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8]
# 'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0]
# 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256]
# 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110]
# 'Color': _enhance_level_to_arg,
# 'Contrast': _enhance_level_to_arg,
# 'Brightness': _enhance_level_to_arg,
# 'Sharpness': _enhance_level_to_arg,
# 'ShearX': _shear_level_to_arg,
# 'ShearY': _shear_level_to_arg,
# 'TranslateX': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']),
# 'TranslateY': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']),
# 'TranslateXRel': lambda level: _translate_rel_level_to_arg(level),
# 'TranslateYRel': lambda level: _translate_rel_level_to_arg(level),
# }


NAME_TO_OP = {
'AutoContrast': auto_contrast,
'Equalize': equalize,
'Invert': invert,
'Rotate': rotate,
'Posterize': posterize,
'Posterize2': posterize,
'Solarize': solarize,
'SolarizeAdd': solarize_add,
'Color': color,
'Contrast': contrast,
'Brightness': brightness,
'Sharpness': sharpness,
'ShearX': shear_x,
'ShearY': shear_y,
'TranslateX': translate_x_abs,
'TranslateY': translate_y_abs,
'TranslateXRel': translate_x_rel,
'TranslateYRel': translate_y_rel,
}


def pass_fn(input):
return ()


def _conversion0(input):
return (int((input / _MAX_LEVEL) * 4) + 4,)


def _conversion1(input):
return (4 - int((input / _MAX_LEVEL) * 4),)


def _conversion2(input):
return (int((input / _MAX_LEVEL) * 256),)


def _conversion3(input):
return (int((input / _MAX_LEVEL) * 110),)


class AutoAugmentOp:
def __init__(self, name, prob, magnitude, hparams={}):
self.aug_fn = NAME_TO_OP[name]
# self.level_fn = level_to_arg(hparams)[name]
if name == 'AutoContrast' or name == 'Equalize' or name == 'Invert':
self.level_fn = pass_fn
elif name == 'Rotate':
self.level_fn = _rotate_level_to_arg
elif name == 'Posterize':
self.level_fn = _conversion0
elif name == 'Posterize2':
self.level_fn = _conversion1
elif name == 'Solarize':
self.level_fn = _conversion2
elif name == 'SolarizeAdd':
self.level_fn = _conversion3
elif name == 'Color' or name == 'Contrast' or name == 'Brightness' or name == 'Sharpness':
self.level_fn = _enhance_level_to_arg
elif name == 'ShearX' or name == 'ShearY':
self.level_fn = _shear_level_to_arg
elif name == 'TranslateX' or name == 'TranslateY':
self.level_fn = _translate_abs_level_to_arg2
elif name == 'TranslateXRel' or name == 'TranslateYRel':
self.level_fn = _translate_rel_level_to_arg
else:
print("{} not recognized".format({}))
self.prob = prob
self.magnitude = magnitude
# If std deviation of magnitude is > 0, we introduce some randomness
# in the usually fixed policy and sample magnitude from normal dist
# with mean magnitude and std-dev of magnitude_std.
# NOTE This is being tested as it's not in paper or reference impl.
self.magnitude_std = 0.5 # FIXME add arg/hparam
self.kwargs = {
'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL,
'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION
}

def __call__(self, img):
if self.prob < random.random():
return img
magnitude = self.magnitude
if self.magnitude_std and self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std)
magnitude = min(_MAX_LEVEL, max(0, magnitude))
level_args = self.level_fn(magnitude)
return self.aug_fn(img, *level_args, **self.kwargs)


def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
# ImageNet policy from TPU EfficientNet impl, cannot find
# a paper reference.
policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy]
return pc


def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
# ImageNet policy from https://arxiv.org/abs/1805.09501
policy = [
[('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
[('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
[('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
[('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
]
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy]
return pc


def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT):
if name == 'original':
return auto_augment_policy_original(hparams)
elif name == 'v0':
return auto_augment_policy_v0(hparams)
else:
print("Unknown auto_augmentation policy {}".format(name))
raise AssertionError()


class AutoAugment:

def __init__(self, policy):
self.policy = policy

def __call__(self, img):
sub_policy = random.choice(self.policy)
for op in sub_policy:
img = op(img)
return img

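A short usage sketch of the policy classes above (the image path is illustrative):

from PIL import Image
from xnas.datasets.auto_augment_tf import auto_augment_policy, AutoAugment

# Build the 'v0' ImageNet policy and apply one randomly chosen sub-policy to a PIL image.
policy = auto_augment_policy('v0')
augment = AutoAugment(policy)
img = Image.open('example.jpg').convert('RGB')  # illustrative input
augmented = augment(img)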
+ 64
- 136
xnas/datasets/imagenet.py View File

@@ -1,99 +1,67 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""ImageNet dataset.""" """ImageNet dataset."""


import math
import os import os
import re import re


import numpy as np import numpy as np
import torch import torch
import torch.utils.data import torch.utils.data
import torchvision.transforms as torch_transforms
from PIL import Image from PIL import Image
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler


import xnas.logger.logging as logging import xnas.logger.logging as logging
from xnas.core.config import cfg from xnas.core.config import cfg
from xnas.datasets.transforms import MultiSizeRandomCrop
from xnas.datasets.transforms_imagenet import get_data_transform




logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)




class ImageFolder(): class ImageFolder():
def __init__(
self,
datapath,
split,
batch_size=None,
dataset_name='imagenet',
_rgb_normalized_mean=None,
_rgb_normalized_std=None,
transforms=None,
num_workers=None,
pin_memory=None,
shuffle=True
):
"""New ImageFolder
Support ImageNet only currently.
"""
def __init__(self, datapath, batch_size, split=None, use_val=False, augment_type='default', **kwargs):
datapath = './data/imagenet/' if not datapath else datapath datapath = './data/imagenet/' if not datapath else datapath
assert os.path.exists(datapath), "Data path '{}' not found".format(datapath) assert os.path.exists(datapath), "Data path '{}' not found".format(datapath)
self.use_val = cfg.LOADER.USE_VAL
self._data_path, self._split, self.dataset_name = datapath, split, dataset_name
self._rgb_normalized_mean, self._rgb_normalized_std = _rgb_normalized_mean, _rgb_normalized_std
self.num_workers = cfg.LOADER.NUM_WORKERS if num_workers is None else num_workers
self.pin_memory = cfg.LOADER.PIN_MEMORY if pin_memory is None else pin_memory
self.shuffle = shuffle
self.use_val = use_val
self.data_path = datapath
self.split = split
self.batch_size = batch_size self.batch_size = batch_size
if not self.use_val:
assert sum(self.split) == 1, "Summation of split should be 1"
self.msrc = None
self.loader = torch.utils.data.DataLoader
# self.collate_fn = None
self.num_workers = cfg.LOADER.NUM_WORKERS
self.pin_memory = cfg.LOADER.PIN_MEMORY
self.augment_type = augment_type
self.kwargs = kwargs
if transforms is None:
im_size = cfg.SEARCH.IM_SIZE if len(cfg.SEARCH.MULTI_SIZES)==0 else cfg.SEARCH.MULTI_SIZES
self.transforms = [{'crop': 'random', 'crop_size': im_size, 'min_crop': 0.08, 'random_flip': True},
{'crop': 'center', 'crop_size': cfg.TEST.IM_SIZE, 'min_crop': -1, 'random_flip': False}] # NOTE: min_crop is not used here.
else:
self.transforms = transforms
if not self.use_val: if not self.use_val:
assert len(self.transforms) == len(self._split), "Length of split and transforms should be equal"
else:
assert len(self.transforms) == 2
assert sum(self.split) == 1, "Summation of split should be 1."
# Check if using multisize_random_crop
if len(cfg.SEARCH.MULTI_SIZES):
# setting default loader if not using MultiSizeRandomCrop
if len(cfg.SEARCH.MULTI_SIZES) == 0:
self.loader = torch.utils.data.DataLoader
else:
from xnas.datasets.utils.msrc_loader import msrc_DataLoader from xnas.datasets.utils.msrc_loader import msrc_DataLoader
self.msrc = MultiSizeRandomCrop(cfg.SEARCH.MULTI_SIZES)
self.loader = msrc_DataLoader self.loader = msrc_DataLoader
logger.info("Using Random MultiSize Crop, continuous={} candidate im_sizes={}".format(self.msrc.CONTINUOUS, self.msrc.CANDIDATE_SIZES))
logger.info("Using MultiSize RandomCrop, continuous={} candidate im_sizes={}".format(self.msrc.CONTINUOUS, self.msrc.CANDIDATE_SIZES))


# Read all dataset
# Acquiring transforms
logger.info("Constructing transforms")
        self.train_transform, self.test_transform = self._build_transforms()
# Read all datasets
logger.info("Constructing ImageFolder") logger.info("Constructing ImageFolder")
self._construct_imdb() self._construct_imdb()

def _construct_imdb(self): def _construct_imdb(self):
# Images are stored per class in subdirs (format: n<number>) # Images are stored per class in subdirs (format: n<number>)
if not self.use_val: if not self.use_val:
split_files = os.listdir(self._data_path)
else:
split_files = os.listdir(os.path.join(self._data_path, "train"))
if self.dataset_name == "imagenet":
# imagenet format folder names
self._class_ids = sorted(
f for f in split_files if re.match(r"^n[0-9]+$", f))
self._rgb_normalized_mean = [0.485, 0.456, 0.406]
self._rgb_normalized_std = [0.229, 0.224, 0.225]
elif self.dataset_name == 'custom':
self._class_ids = sorted(
f for f in split_files if not f[0] == '.')
split_files = os.listdir(self.data_path)
else: else:
raise NotImplementedError
split_files = os.listdir(os.path.join(self.data_path, "train"))
# imagenet format folder names
self._class_ids = sorted(
f for f in split_files if re.match(r"^n[0-9]+$", f))


# Map class ids to contiguous ids # Map class ids to contiguous ids
self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)} self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)}
@@ -102,7 +70,7 @@ class ImageFolder():
self._imdb = [] self._imdb = []
for class_id in self._class_ids: for class_id in self._class_ids:
cont_id = self._class_id_cont_id[class_id] cont_id = self._class_id_cont_id[class_id]
train_im_dir = os.path.join(self._data_path, class_id)
train_im_dir = os.path.join(self.data_path, class_id)
for im_name in os.listdir(train_im_dir): for im_name in os.listdir(train_im_dir):
im_path = os.path.join(train_im_dir, im_name) im_path = os.path.join(train_im_dir, im_name)
if is_image_file(im_path): if is_image_file(im_path):
@@ -112,8 +80,8 @@ class ImageFolder():
else: else:
self._train_imdb = [] self._train_imdb = []
self._val_imdb = [] self._val_imdb = []
train_path = os.path.join(self._data_path, "train")
val_path = os.path.join(self._data_path, "val")
train_path = os.path.join(self.data_path, "train")
val_path = os.path.join(self.data_path, "val")
for class_id in self._class_ids: for class_id in self._class_ids:
cont_id = self._class_id_cont_id[class_id] cont_id = self._class_id_cont_id[class_id]
train_im_dir = os.path.join(train_path, class_id) train_im_dir = os.path.join(train_path, class_id)
@@ -129,7 +97,11 @@ class ImageFolder():
logger.info("Number of classes: {}".format(len(self._class_ids))) logger.info("Number of classes: {}".format(len(self._class_ids)))
logger.info("Number of TRAIN images: {}".format(len(self._train_imdb))) logger.info("Number of TRAIN images: {}".format(len(self._train_imdb)))
logger.info("Number of VAL images: {}".format(len(self._val_imdb))) logger.info("Number of VAL images: {}".format(len(self._val_imdb)))

    def _build_transforms(self):
# KWARGS for 'auto_augment_tf': policy='v0', interpolation='bilinear'
return get_data_transform(augment=self.augment_type, **self.kwargs)
def generate_data_loader(self): def generate_data_loader(self):
if not self.use_val: if not self.use_val:
indices = list(range(len(self._imdb))) indices = list(range(len(self._imdb)))
@@ -138,21 +110,22 @@ class ImageFolder():
data_loaders = [] data_loaders = []
pre_partition = 0. pre_partition = 0.
pre_index = 0 pre_index = 0
for i, _split in enumerate(self._split):
for i, _split in enumerate(self.split):
_current_partition = pre_partition + _split _current_partition = pre_partition + _split
_current_index = int(len(self._imdb) * _current_partition) _current_index = int(len(self._imdb) * _current_partition)
_current_indices = indices[pre_index: _current_index] _current_indices = indices[pre_index: _current_index]
assert not len(_current_indices) == 0, "The length of indices is zero!" assert not len(_current_indices) == 0, "The length of indices is zero!"
dataset = ImageList_torch([self._imdb[j] for j in _current_indices],
self.msrc, # add support for multisize_random_crop
_rgb_normalized_mean=self._rgb_normalized_mean,
_rgb_normalized_std=self._rgb_normalized_std,
**self.transforms[i])
dataset = ImageList_torch(
[self._imdb[j] for j in _current_indices],
                # only the first split uses the training transform
transform=self.train_transform if i==0 else self.test_transform
)
sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
loader = self.loader(dataset, loader = self.loader(dataset,
batch_size=self.batch_size[i], batch_size=self.batch_size[i],
shuffle=(False if sampler else True), shuffle=(False if sampler else True),
sampler=sampler, sampler=sampler,
drop_last=(True if i==0 else False),
num_workers=self.num_workers, num_workers=self.num_workers,
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
data_loaders.append(loader) data_loaders.append(loader)
@@ -160,82 +133,37 @@ class ImageFolder():
pre_index = _current_index pre_index = _current_index
return data_loaders return data_loaders
else: else:
train_dataset = ImageList_torch(
self._train_imdb,
self.msrc,
_rgb_normalized_mean=self._rgb_normalized_mean,
_rgb_normalized_std=self._rgb_normalized_std,
**self.transforms[0]
)
sampler = DistributedSampler(train_dataset) if cfg.NUM_GPUS > 1 else None
train_dataset = ImageList_torch(self._train_imdb, self.train_transform)
train_sampler = DistributedSampler(train_dataset) if cfg.NUM_GPUS > 1 else None
train_loader = self.loader(train_dataset, train_loader = self.loader(train_dataset,
batch_size=self.batch_size[0],
shuffle=(False if sampler else True),
sampler=sampler,
num_workers=self.num_workers,
pin_memory=self.pin_memory)
val_dataset = ImageList_torch(
self._val_imdb,
self.msrc,
_rgb_normalized_mean=self._rgb_normalized_mean,
_rgb_normalized_std=self._rgb_normalized_std,
**self.transforms[1]
)
sampler = DistributedSampler(val_dataset) if cfg.NUM_GPUS > 1 else None
batch_size=self.batch_size[0],
shuffle=(False if train_sampler else True),
sampler=train_sampler,
drop_last=True,
num_workers=self.num_workers,
pin_memory=self.pin_memory)
val_dataset = ImageList_torch(self._val_imdb, self.test_transform)
val_sampler = DistributedSampler(val_dataset) if cfg.NUM_GPUS > 1 else None
valid_loader = self.loader(val_dataset, valid_loader = self.loader(val_dataset,
batch_size=self.batch_size[1],
shuffle=(False if sampler else True),
sampler=sampler,
num_workers=self.num_workers,
pin_memory=self.pin_memory)
batch_size=self.batch_size[1],
shuffle=(False if val_sampler else True),
sampler=val_sampler,
drop_last=False,
num_workers=self.num_workers,
pin_memory=self.pin_memory)
return [train_loader, valid_loader] return [train_loader, valid_loader]



class ImageList_torch(torch.utils.data.Dataset): class ImageList_torch(torch.utils.data.Dataset):
''' '''
ImageList dataloader with torch backends ImageList dataloader with torch backends
From https://github.com/pytorch/vision/issues/81 From https://github.com/pytorch/vision/issues/81
''' '''


def __init__(
self,
_list,
msrc=None,
_rgb_normalized_mean=None,
_rgb_normalized_std=None,
crop='random',
crop_size=224,
min_crop=0.08,
random_flip=False):
self._imdb = _list
self.msrc = msrc
self._bgr_normalized_mean = _rgb_normalized_mean[::-1]
self._bgr_normalized_std = _rgb_normalized_std[::-1]
self.crop = crop
self.crop_size = crop_size
self.min_crop = min_crop
self.random_flip = random_flip
    def __init__(self, _list, transform):
        self._imdb = _list
self.transform = transform
self.loader = pil_loader self.loader = pil_loader
self._construct_transforms()

def _construct_transforms(self):
transforms = []
if self.crop == "random":
if isinstance(self.crop_size, int):
transforms.append(torch_transforms.RandomResizedCrop(self.crop_size, scale=(self.min_crop, 1.0)))
elif isinstance(self.crop_size, list):
# using MultiSizeRandomCrop
transforms.append(self.msrc)
elif self.crop == "center":
transforms.append(torch_transforms.Resize(math.ceil(self.crop_size / 0.875))) # assert crop_size==224
transforms.append(torch_transforms.CenterCrop(self.crop_size))
# TODO: color augmentation support
# transforms.append(torch_transforms.ColorJitter(brightness=0.4, contrast=0.4,saturation=0.4, hue=0.2))
if self.random_flip:
transforms.append(torch_transforms.RandomHorizontalFlip())
transforms.append(torch_transforms.ToTensor())
transforms.append(torch_transforms.Normalize(mean=self._bgr_normalized_mean, std=self._bgr_normalized_std))
self.transform = torch_transforms.Compose(transforms)


def __getitem__(self, index): def __getitem__(self, index):
impath = self._imdb[index]["im_path"] impath = self._imdb[index]["im_path"]


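A hedged sketch of the refactored ImageFolder interface (path, batch sizes and split are illustrative):

from xnas.datasets.imagenet import ImageFolder

# Split a single training folder by ratio; with use_val=True the loader would
# instead read separate train/ and val/ subfolders.
train_loader, valid_loader = ImageFolder(
    datapath="./data/imagenet/ILSVRC2012_img_train/",
    batch_size=[256, 128],
    split=[0.8, 0.2],
    use_val=False,
    augment_type="auto_augment_tf",
).generate_data_loader()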
+ 42
- 30
xnas/datasets/loader.py View File

@@ -26,21 +26,22 @@ def construct_loader(
cutout_length=0, cutout_length=0,
use_classes=None, use_classes=None,
transforms=None, transforms=None,
**kwargs
): ):
"""Construct NAS dataloaders with train&valid subsets.""" """Construct NAS dataloaders with train&valid subsets."""
split = cfg.LOADER.SPLIT split = cfg.LOADER.SPLIT
name = cfg.LOADER.DATASET name = cfg.LOADER.DATASET
batch_size = cfg.LOADER.BATCH_SIZE
batch_size = cfg.LOADER.BATCH_SIZE
datapath = cfg.LOADER.DATAPATH datapath = cfg.LOADER.DATAPATH
assert (name in SUPPORTED_DATASETS) or (name in IMAGEFOLDER_FORMAT), "dataset not supported." assert (name in SUPPORTED_DATASETS) or (name in IMAGEFOLDER_FORMAT), "dataset not supported."


# expand batch_size to support different number during training & validating # expand batch_size to support different number during training & validating
if isinstance(batch_size, int): if isinstance(batch_size, int):
batch_size = [batch_size, batch_size]
batch_size = [batch_size] * len(split)
elif batch_size is None: elif batch_size is None:
batch_size = [256, 256]
batch_size = [256] * len(split)
assert len(batch_size) == len(split), "lengths of batch_size and split should be same." assert len(batch_size) == len(split), "lengths of batch_size and split should be same."
# check if randomresized crop is used only in ImageFolder type datasets # check if randomresized crop is used only in ImageFolder type datasets
@@ -52,9 +53,11 @@ def construct_loader(
train_data, _ = get_data(name, datapath, cutout_length, use_classes=use_classes, transforms=transforms) train_data, _ = get_data(name, datapath, cutout_length, use_classes=use_classes, transforms=transforms)
return split_dataloader(train_data, batch_size, split) return split_dataloader(train_data, batch_size, split)
elif name in IMAGEFOLDER_FORMAT: elif name in IMAGEFOLDER_FORMAT:
        assert cfg.LOADER.USE_VAL is False, "VAL dataset should not be used here."
aug_type = cfg.LOADER.TRANSFORM
return ImageFolder( # using path of training data of ImageNet as `datapath` return ImageFolder( # using path of training data of ImageNet as `datapath`
datapath, split, batch_size=batch_size,
transforms=transforms,
datapath, batch_size=batch_size, split=split,
use_val=False, augment_type=aug_type, **kwargs
).generate_data_loader() ).generate_data_loader()
else: else:
print("dataset not supported.") print("dataset not supported.")
@@ -137,39 +140,48 @@ def get_normal_dataloader(
name=None, name=None,
train_batch=None, train_batch=None,
cutout_length=0, cutout_length=0,
download=True,
use_classes=None, use_classes=None,
transforms=None, transforms=None,
**kwargs
): ):
name=cfg.LOADER.DATASET if name is None else name name=cfg.LOADER.DATASET if name is None else name
train_batch=cfg.LOADER.BATCH_SIZE if train_batch is None else train_batch train_batch=cfg.LOADER.BATCH_SIZE if train_batch is None else train_batch
name=cfg.LOADER.DATASET name=cfg.LOADER.DATASET
root=cfg.LOADER.DATAPATH
test_batch=cfg.TEST.BATCH_SIZE
datapath=cfg.LOADER.DATAPATH
test_batch=cfg.LOADER.BATCH_SIZE if cfg.TEST.BATCH_SIZE == -1 else cfg.TEST.BATCH_SIZE
# get normal dataloaders with train&test subsets.
train_data, test_data = get_data(name, root, cutout_length, download, use_classes, transforms)
# if loader.batch_size is a list for [train, val_1, ...], the first value will be used.
if isinstance(train_batch, list):
train_batch = train_batch[0]
train_loader = data.DataLoader(
dataset=train_data,
batch_size=train_batch,
shuffle=True,
num_workers=cfg.LOADER.NUM_WORKERS,
pin_memory=cfg.LOADER.PIN_MEMORY,
)
test_loader = data.DataLoader(
dataset=test_data,
batch_size=test_batch,
shuffle=False,
num_workers=cfg.LOADER.NUM_WORKERS,
pin_memory=cfg.LOADER.PIN_MEMORY,
)
return train_loader, test_loader
assert (name in SUPPORTED_DATASETS) or (name in IMAGEFOLDER_FORMAT), "dataset not supported."
    assert isinstance(train_batch, int), "normal dataloader uses a single training batch size, not a list."
# check if randomresized crop is used only in ImageFolder type datasets
if len(cfg.SEARCH.MULTI_SIZES):
assert name in IMAGEFOLDER_FORMAT, "RandomResizedCrop can only be used in ImageFolder currently."


if name in SUPPORTED_DATASETS:
# get normal dataloaders with train&test subsets.
train_data, test_data = get_data(name, datapath, cutout_length, use_classes=use_classes, transforms=transforms)
train_loader = data.DataLoader(
dataset=train_data,
batch_size=train_batch,
shuffle=True,
num_workers=cfg.LOADER.NUM_WORKERS,
pin_memory=cfg.LOADER.PIN_MEMORY,
)
test_loader = data.DataLoader(
dataset=test_data,
batch_size=test_batch,
shuffle=False,
num_workers=cfg.LOADER.NUM_WORKERS,
pin_memory=cfg.LOADER.PIN_MEMORY,
)
return train_loader, test_loader
elif name in IMAGEFOLDER_FORMAT:
        assert cfg.LOADER.USE_VAL is True, "normal dataloader requires LOADER.USE_VAL=True."
aug_type = cfg.LOADER.TRANSFORM
return ImageFolder( # using path of training data of ImageNet as `datapath`
datapath, batch_size=[train_batch, test_batch],
use_val=True, augment_type=aug_type, **kwargs
).generate_data_loader()


def split_dataloader(data_, batch_size, split): def split_dataloader(data_, batch_size, split):
assert 0 not in split, "illegal split list with zero." assert 0 not in split, "illegal split list with zero."


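A sketch of the ImageFolder branch of get_normal_dataloader under the new config handling (values illustrative):

from xnas.core.config import cfg
from xnas.datasets.loader import get_normal_dataloader

# With an ImageFolder-format dataset, USE_VAL must be True; TEST.BATCH_SIZE == -1
# falls back to LOADER.BATCH_SIZE for the test loader.
cfg.LOADER.DATASET = "imagenet"
cfg.LOADER.USE_VAL = True
cfg.TEST.BATCH_SIZE = -1
train_loader, test_loader = get_normal_dataloader()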
+ 124
- 0
xnas/datasets/transforms_imagenet.py View File

@@ -0,0 +1,124 @@
import math
import torch
from PIL import Image
import torchvision.transforms as transforms

from xnas.core.config import cfg
from xnas.datasets.auto_augment_tf import auto_augment_policy, AutoAugment


IMAGENET_RGB_MEAN = [0.485, 0.456, 0.406]
IMAGENET_RGB_STD = [0.229, 0.224, 0.225]


def get_data_transform(augment, **kwargs):
if len(cfg.SEARCH.MULTI_SIZES)==0:
# using single image_size for training
train_crop_size = cfg.SEARCH.IM_SIZE
else:
# using MultiSize_RandomCrop
train_crop_size = cfg.SEARCH.MULTI_SIZES
min_train_scale = 0.08
test_scale = math.ceil(cfg.TEST.IM_SIZE / 0.875) # 224 / 0.875 = 256
test_crop_size = cfg.TEST.IM_SIZE # do not crop and using 224 by default.

interpolation = transforms.InterpolationMode.BICUBIC
if 'interpolation' in kwargs.keys() and kwargs['interpolation'] == 'bilinear':
interpolation = transforms.InterpolationMode.BILINEAR
da_args = {
'train_crop_size': train_crop_size,
'train_min_scale': min_train_scale,
'test_scale': test_scale,
'test_crop_size': test_crop_size,
'interpolation': interpolation,
}

if augment == 'default':
return build_default_transform(**da_args)
elif augment == 'auto_augment_tf':
policy = 'v0' if 'policy' not in kwargs.keys() else kwargs['policy']
return build_imagenet_auto_augment_tf_transform(policy=policy, **da_args)
else:
raise ValueError(augment)


def get_normalize():
return transforms.Normalize(
mean=torch.Tensor(IMAGENET_RGB_MEAN),
std=torch.Tensor(IMAGENET_RGB_STD),
)


def get_randomResizedCrop(train_crop_size=224, train_min_scale=0.08, interpolation=transforms.InterpolationMode.BICUBIC):
if isinstance(train_crop_size, int):
return transforms.RandomResizedCrop(train_crop_size, scale=(train_min_scale, 1.0), interpolation=interpolation)
elif isinstance(train_crop_size, list):
from xnas.datasets.transforms import MultiSizeRandomCrop
msrc = MultiSizeRandomCrop(train_crop_size)
return msrc
else:
raise TypeError(train_crop_size)


def build_default_transform(
train_crop_size=224, train_min_scale=0.08, test_scale=256, test_crop_size=224, interpolation=transforms.InterpolationMode.BICUBIC
):
normalize = get_normalize()
train_crop_transform = get_randomResizedCrop(
train_crop_size, train_min_scale, interpolation
)
train_transform = transforms.Compose(
[
# transforms.RandomResizedCrop(train_crop_size, interpolation=interpolation),
train_crop_transform,
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
)
test_transform = transforms.Compose(
[
transforms.Resize(test_scale, interpolation=interpolation),
transforms.CenterCrop(test_crop_size),
transforms.ToTensor(),
normalize,
]
)
return train_transform, test_transform


def build_imagenet_auto_augment_tf_transform(
policy='v0', train_crop_size=224, train_min_scale=0.08, test_scale=256, test_crop_size=224, interpolation=transforms.InterpolationMode.BICUBIC
):

normalize = get_normalize()
img_size = train_crop_size
aa_params = {
"translate_const": int(img_size * 0.45),
"img_mean": tuple(round(x) for x in IMAGENET_RGB_MEAN),
}

aa_policy = AutoAugment(auto_augment_policy(policy, aa_params))
train_crop_transform = get_randomResizedCrop(
train_crop_size, train_min_scale, interpolation
)
train_transform = transforms.Compose(
[
# transforms.RandomResizedCrop(train_crop_size, interpolation=interpolation),
train_crop_transform,
transforms.RandomHorizontalFlip(),
aa_policy,
transforms.ToTensor(),
normalize,
]
)
test_transform = transforms.Compose(
[
transforms.Resize(test_scale, interpolation=interpolation),
transforms.CenterCrop(test_crop_size),
transforms.ToTensor(),
normalize,
]
)
return train_transform, test_transform

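A sketch of building transforms directly through the new helper (kwargs follow the hint in imagenet.py):

from xnas.datasets.transforms_imagenet import get_data_transform

# Train/test transforms with the TF AutoAugment policy; 'policy' and 'interpolation'
# are optional kwargs, defaulting to 'v0' and bicubic respectively.
train_tf, test_tf = get_data_transform(augment="auto_augment_tf",
                                       policy="v0", interpolation="bilinear")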
+ 36
- 4
xnas/runner/criterion.py View File

@@ -32,6 +32,35 @@ def CrossEntropyLoss_label_smoothed(pred, target, label_smoothing=0.):
return CrossEntropyLoss_soft_target(pred, soft_target) return CrossEntropyLoss_soft_target(pred, soft_target)




class KLLossSoft(torch.nn.modules.loss._Loss):
""" inplace distillation for image classification
output: output logits of the student network
target: output logits of the teacher network
T: temperature
    KL(p||q) = E_p[log p] - E_p[log q]
"""
def forward(self, output, soft_logits, target=None, temperature=1., alpha=0.9):
output, soft_logits = output / temperature, soft_logits / temperature
soft_target_prob = F.softmax(soft_logits, dim=1)
output_log_prob = F.log_softmax(output, dim=1)
kd_loss = -torch.sum(soft_target_prob * output_log_prob, dim=1)
if target is not None:
n_class = output.size(1)
target = torch.zeros_like(output).scatter(1, target.view(-1, 1), 1)
target = target.unsqueeze(1)
output_log_prob = output_log_prob.unsqueeze(2)
ce_loss = -torch.bmm(target, output_log_prob).squeeze()
loss = alpha * temperature * temperature * kd_loss + (1.0 - alpha) * ce_loss
else:
loss = kd_loss
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
return loss


class MultiHeadCrossEntropyLoss(nn.Module): class MultiHeadCrossEntropyLoss(nn.Module):
def forward(self, preds, targets): def forward(self, preds, targets):
assert preds.dim() == 3, preds assert preds.dim() == 3, preds
@@ -50,12 +79,15 @@ class MultiHeadCrossEntropyLoss(nn.Module):


SUPPORTED_CRITERIONS = { SUPPORTED_CRITERIONS = {
"cross_entropy": torch.nn.CrossEntropyLoss(), "cross_entropy": torch.nn.CrossEntropyLoss(),
"cross_entropy_soft": CrossEntropyLoss_soft_target,
"cross_entropy_smooth": CrossEntropyLoss_label_smoothed, "cross_entropy_smooth": CrossEntropyLoss_label_smoothed,
"cross_entropy_multihead": MultiHeadCrossEntropyLoss()
"cross_entropy_multihead": MultiHeadCrossEntropyLoss(),
"kl_soft": KLLossSoft(),
} }




def criterion_builder():
def criterion_builder(criterion=None):
err_str = "Loss function type '{}' not supported" err_str = "Loss function type '{}' not supported"
assert cfg.SEARCH.LOSS_FUN in SUPPORTED_CRITERIONS.keys(), err_str.format(cfg.SEARCH.LOSS_FUN)
return SUPPORTED_CRITERIONS[cfg.SEARCH.LOSS_FUN]
loss_fun = cfg.SEARCH.LOSS_FUN if criterion is None else criterion
assert loss_fun in SUPPORTED_CRITERIONS.keys(), err_str.format(loss_fun)
return SUPPORTED_CRITERIONS[loss_fun]

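A hedged sketch of inplace distillation with the new criterion (tensors are random stand-ins):

import torch
from xnas.runner.criterion import criterion_builder

# The largest subnet's logits serve as soft targets for a sampled subnet; passing
# `target` mixes in a plain cross-entropy term weighted by (1 - alpha).
kd_loss_fn = criterion_builder("kl_soft")
student_logits = torch.randn(8, 1000)
teacher_logits = torch.randn(8, 1000).detach()
labels = torch.randint(0, 1000, (8,))
loss = kd_loss_fn(student_logits, teacher_logits, target=labels)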
+ 0
- 1
xnas/runner/optimizer.py View File

@@ -1,7 +1,6 @@
"""Optimizers.""" """Optimizers."""


import torch import torch
import torch.nn as nn
from xnas.core.config import cfg from xnas.core.config import cfg






+ 4
- 2
xnas/runner/scheduler.py View File

@@ -85,14 +85,15 @@ class GradualWarmupScheduler(_LRScheduler):
def _calc_learning_rate( def _calc_learning_rate(
init_lr, n_epochs, epoch, n_iter=None, iter=0, init_lr, n_epochs, epoch, n_iter=None, iter=0,
): ):
if cfg.SEARCH.LOSS_FUN.startswith("cross_entropy"):
if cfg.OPTIM.LR_POLICY == "cos":
t_total = n_epochs * n_iter t_total = n_epochs * n_iter
t_cur = epoch * n_iter + iter t_cur = epoch * n_iter + iter
lr = 0.5 * init_lr * (1 + math.cos(math.pi * t_cur / t_total)) lr = 0.5 * init_lr * (1 + math.cos(math.pi * t_cur / t_total))
else: else:
raise ValueError("do not support: {}".format(cfg.SEARCH.LOSS_FUN))
raise ValueError("do not support: {}".format(cfg.OPTIM.LR_POLICY))
return lr return lr



def _warmup_adjust_learning_rate( def _warmup_adjust_learning_rate(
init_lr, n_epochs, epoch, n_iter, iter=0, warmup_lr=0 init_lr, n_epochs, epoch, n_iter, iter=0, warmup_lr=0
): ):
@@ -102,6 +103,7 @@ def _warmup_adjust_learning_rate(
new_lr = T_cur / t_total * (init_lr - warmup_lr) + warmup_lr new_lr = T_cur / t_total * (init_lr - warmup_lr) + warmup_lr
return new_lr return new_lr



def adjust_learning_rate_per_batch(epoch, n_iter=None, iter=0, warmup=False): def adjust_learning_rate_per_batch(epoch, n_iter=None, iter=0, warmup=False):
"""adjust learning of a given optimizer and return the new learning rate""" """adjust learning of a given optimizer and return the new learning rate"""


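The per-batch schedule now keys off OPTIM.LR_POLICY instead of the loss type; a minimal sketch of the cosine rule it implements (values illustrative):

import math

# lr(t) = 0.5 * init_lr * (1 + cos(pi * t / T)), with t counted in batches.
def cosine_lr(init_lr, epoch, n_epochs, it, n_iter):
    t_total = n_epochs * n_iter
    t_cur = epoch * n_iter + it
    return 0.5 * init_lr * (1 + math.cos(math.pi * t_cur / t_total))

print(cosine_lr(0.1, epoch=30, n_epochs=100, it=0, n_iter=500))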
+ 7
- 7
xnas/runner/trainer.py View File

@@ -114,9 +114,9 @@ class Trainer(Recorder):
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter) self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic() self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_median()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
top1_err = self.test_meter.mb_top1_err.get_win_avg()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats # Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch) self.test_meter.log_epoch_stats(cur_epoch)
self.test_meter.reset() self.test_meter.reset()
@@ -349,9 +349,9 @@ class OneShotTrainer(Trainer):
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter) self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic() self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_median()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
top1_err = self.test_meter.mb_top1_err.get_win_avg()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats # Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch) self.test_meter.log_epoch_stats(cur_epoch)
self.test_meter.reset() self.test_meter.reset()
@@ -372,7 +372,7 @@ class OneShotTrainer(Trainer):
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item() top1_err, top5_err = top1_err.item(), top5_err.item()
self.evaluate_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.evaluate_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
top1_err = self.evaluate_meter.mb_top1_err.get_win_median()
top1_err = self.evaluate_meter.mb_top1_err.get_win_avg()
# self.evaluate_sampler.record(choice, top1_err) # self.evaluate_sampler.record(choice, top1_err)
self.evaluate_meter.reset() self.evaluate_meter.reset()
return top1_err return top1_err


+ 8
- 9
xnas/runner/trainer_spos.py View File

@@ -114,9 +114,9 @@ class Trainer(Recorder):
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter) self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic() self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_median()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
top1_err = self.test_meter.mb_top1_err.get_win_avg()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats # Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch) self.test_meter.log_epoch_stats(cur_epoch)
self.test_meter.reset() self.test_meter.reset()
@@ -381,10 +381,9 @@ class OneShotTrainer(Trainer):
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter) self.test_meter.log_iter_stats(cur_epoch, cur_iter)
self.test_meter.iter_tic() self.test_meter.iter_tic()
top1_err = self.test_meter.mb_top1_err.get_win_median()
top1_err_avg = self.test_meter.mb_top1_err.get_global_avg()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_median(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_median(), cur_epoch)
top1_err = self.test_meter.mb_top1_err.get_win_avg()
self.writer.add_scalar('val/top1_error', self.test_meter.mb_top1_err.get_win_avg(), cur_epoch)
self.writer.add_scalar('val/top5_error', self.test_meter.mb_top5_err.get_win_avg(), cur_epoch)
# Log epoch stats # Log epoch stats
self.test_meter.log_epoch_stats(cur_epoch) self.test_meter.log_epoch_stats(cur_epoch)
self.test_meter.reset() self.test_meter.reset()
@@ -392,7 +391,7 @@ class OneShotTrainer(Trainer):
if self.best_err > top1_err: if self.best_err > top1_err:
self.best_err = top1_err self.best_err = top1_err
self.saving(cur_epoch, best=True) self.saving(cur_epoch, best=True)
return top1_err_avg
return top1_err
@torch.no_grad() @torch.no_grad()
def evaluate_epoch(self, sample): def evaluate_epoch(self, sample):
@@ -405,7 +404,7 @@ class OneShotTrainer(Trainer):
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item() top1_err, top5_err = top1_err.item(), top5_err.item()
self.evaluate_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.evaluate_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
top1_err = self.evaluate_meter.mb_top1_err.get_win_median()
top1_err = self.evaluate_meter.mb_top1_err.get_win_avg()
# self.evaluate_sampler.record(choice, top1_err) # self.evaluate_sampler.record(choice, top1_err)
self.evaluate_meter.reset() self.evaluate_meter.reset()
return top1_err return top1_err


+ 652
- 0
xnas/spaces/AttentiveNAS/cnn.py View File

@@ -0,0 +1,652 @@
# Implementation adapted from AttentiveNAS: https://github.com/facebookresearch/AttentiveNAS

import random
from copy import deepcopy
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from xnas.spaces.OFA.ops import ResidualBlock
from xnas.spaces.OFA.dynamic_ops import DynamicLinearLayer
from xnas.spaces.OFA.utils import val2list, make_divisible
from xnas.spaces.BigNAS.dynamic_layers import DynamicMBConvLayer, DynamicConvLayer, DynamicShortcutLayer


class AttentiveNasStaticModel(nn.Module):

def __init__(self, first_conv, blocks, last_conv, classifier, resolution, use_v3_head=True):
super(AttentiveNasStaticModel, self).__init__()
self.first_conv = first_conv
self.blocks = nn.ModuleList(blocks)
self.last_conv = last_conv
self.classifier = classifier

self.resolution = resolution #input size
self.use_v3_head = use_v3_head

def forward(self, x):
# resize input to target resolution first
# Rule: transform images into different sizes
if x.size(-1) != self.resolution:
x = F.interpolate(x, size=self.resolution, mode='bicubic')

x = self.first_conv(x)
for block in self.blocks:
x = block(x)
x = self.last_conv(x)
if not self.use_v3_head:
x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
x = torch.squeeze(x)
x = self.classifier(x)
return x


@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
for block in self.blocks:
_str += block.module_str + '\n'
#_str += self.last_conv.module_str + '\n'
_str += self.classifier.module_str
return _str

@property
def config(self):
return {
'name': AttentiveNasStaticModel.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
#'last_conv': self.last_conv.config,
'classifier': self.classifier.config,
'resolution': self.resolution
}


def weight_initialization(self):
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)

@staticmethod
def build_from_config(config):
raise NotImplementedError

def set_bn_param(self, momentum, eps):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
if momentum is not None:
m.momentum = float(momentum)
else:
m.momentum = None
m.eps = float(eps)
return

def get_bn_param(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
return {
'momentum': m.momentum,
'eps': m.eps,
}
return None

def reset_running_stats_for_calibration(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
m.training = True
m.momentum = None # cumulative moving average
m.reset_running_stats()


class AttentiveNasDynamicModel(nn.Module):

def __init__(self, supernet_cfg, n_classes=1000, bn_param=(0., 1e-5)):
super(AttentiveNasDynamicModel, self).__init__()

self.supernet_cfg = supernet_cfg
self.n_classes = n_classes
self.use_v3_head = getattr(self.supernet_cfg, 'use_v3_head', False)
self.stage_names = ['first_conv', 'mb1', 'mb2', 'mb3', 'mb4', 'mb5', 'mb6', 'mb7', 'last_conv']

self.width_list, self.depth_list, self.ks_list, self.expand_ratio_list = [], [], [], []
for name in self.stage_names:
block_cfg = getattr(self.supernet_cfg, name)
self.width_list.append(block_cfg.c)
if name.startswith('mb'):
self.depth_list.append(block_cfg.d)
self.ks_list.append(block_cfg.k)
self.expand_ratio_list.append(block_cfg.t)
self.resolution_list = self.supernet_cfg.resolutions

self.cfg_candidates = {
'resolution': self.resolution_list,
'width': self.width_list,
'depth': self.depth_list,
'kernel_size': self.ks_list,
'expand_ratio': self.expand_ratio_list
}

#first conv layer, including conv, bn, act
out_channel_list, act_func, stride = \
self.supernet_cfg.first_conv.c, self.supernet_cfg.first_conv.act_func, self.supernet_cfg.first_conv.s
self.first_conv = DynamicConvLayer(
in_channel_list=val2list(3), out_channel_list=out_channel_list,
kernel_size=3, stride=stride, act_func=act_func,
)

# inverted residual blocks
self.block_group_info = []
blocks = []
_block_index = 0
feature_dim = out_channel_list
for stage_id, key in enumerate(self.stage_names[1:-1]):
block_cfg = getattr(self.supernet_cfg, key)
width = block_cfg.c
n_block = max(block_cfg.d)
act_func = block_cfg.act_func
ks = block_cfg.k
expand_ratio_list = block_cfg.t
use_se = block_cfg.se

self.block_group_info.append([_block_index + i for i in range(n_block)])
_block_index += n_block

output_channel = width
for i in range(n_block):
stride = block_cfg.s if i == 0 else 1
if min(expand_ratio_list) >= 4:
expand_ratio_list = [_s for _s in expand_ratio_list if _s >= 4] if i == 0 else expand_ratio_list
mobile_inverted_conv = DynamicMBConvLayer(
in_channel_list=feature_dim,
out_channel_list=output_channel,
kernel_size_list=ks,
expand_ratio_list=expand_ratio_list,
stride=stride,
act_func=act_func,
use_se=use_se,
channels_per_group=getattr(self.supernet_cfg, 'channels_per_group', 1)
)
# Rule: add skip-connect, and use 2x2 AvgPool or 1x1 Conv for adaptation
shortcut = DynamicShortcutLayer(feature_dim, output_channel, reduction=stride)
blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
feature_dim = output_channel
self.blocks = nn.ModuleList(blocks)

last_channel, act_func = self.supernet_cfg.last_conv.c, self.supernet_cfg.last_conv.act_func
if not self.use_v3_head:
self.last_conv = DynamicConvLayer(
in_channel_list=feature_dim, out_channel_list=last_channel,
kernel_size=1, act_func=act_func,
)
else:
expand_feature_dim = [f_dim * 6 for f_dim in feature_dim]
self.last_conv = nn.Sequential(OrderedDict([
('final_expand_layer', DynamicConvLayer(
feature_dim, expand_feature_dim, kernel_size=1, use_bn=True, act_func=act_func)
),
('pool', nn.AdaptiveAvgPool2d((1,1))),
('feature_mix_layer', DynamicConvLayer(
in_channel_list=expand_feature_dim, out_channel_list=last_channel,
kernel_size=1, act_func=act_func, use_bn=False,)
),
]))

#final conv layer
self.classifier = DynamicLinearLayer(
in_features_list=last_channel, out_features=n_classes, bias=True
)

# set bn param
self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

# runtime_depth
self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

self.zero_residual_block_bn_weights()

self.active_dropout_rate = 0
self.active_drop_connect_rate = 0
self.active_resolution = 224

# Rule: Initialize learnable coefficient \gamma=0
def zero_residual_block_bn_weights(self):
with torch.no_grad():
for m in self.modules():
if isinstance(m, ResidualBlock):
if isinstance(m.mobile_inverted_conv, DynamicMBConvLayer) and m.shortcut is not None:
m.mobile_inverted_conv.point_linear.bn.bn.weight.zero_()

@staticmethod
def name():
return 'AttentiveNasModel'

def forward(self, x):
# resize input to target resolution first
if x.size(-1) != self.active_resolution:
x = F.interpolate(x, size=self.active_resolution, mode='bicubic')

# first conv
x = self.first_conv(x)
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
x = self.blocks[idx](x)

x = self.last_conv(x)
x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
x = torch.squeeze(x)

if self.active_dropout_rate > 0 and self.training:
x = F.dropout(x, p = self.active_dropout_rate)

x = self.classifier(x)
return x


@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
_str += self.blocks[0].module_str + '\n'

for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
_str += self.blocks[idx].module_str + '\n'
if not self.use_v3_head:
_str += self.last_conv.module_str + '\n'
else:
_str += self.last_conv.final_expand_layer.module_str + '\n'
_str += self.last_conv.feature_mix_layer.module_str + '\n'
_str += self.classifier.module_str + '\n'
return _str

@property
def config(self):
return {
'name': AttentiveNasDynamicModel.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
'last_conv': self.last_conv.config if not self.use_v3_head else None,
'final_expand_layer': self.last_conv.final_expand_layer if self.use_v3_head else None,
'feature_mix_layer': self.last_conv.feature_mix_layer if self.use_v3_head else None,
'classifier': self.classifier.config,
'resolution': self.active_resolution
}


@staticmethod
def build_from_config(config):
raise NotImplementedError

def get_parameters(self, keys=None, mode="include"):
if keys is None:
for name, param in self.named_parameters():
if param.requires_grad:
yield param
elif mode == "include":
for name, param in self.named_parameters():
flag = False
for key in keys:
if key in name:
flag = True
break
if flag and param.requires_grad:
yield param
elif mode == "exclude":
for name, param in self.named_parameters():
flag = True
for key in keys:
if key in name:
flag = False
break
if flag and param.requires_grad:
yield param
else:
raise ValueError("do not support: %s" % mode)

def set_bn_param(self, momentum, eps):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
if momentum is not None:
m.momentum = float(momentum)
else:
m.momentum = None
m.eps = float(eps)
return

def get_bn_param(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
return {
'momentum': m.momentum,
'eps': m.eps,
}
return None

""" set, sample and get active sub-networks """
def set_active_subnet(self, resolution=224, width=None, depth=None, kernel_size=None, expand_ratio=None, **kwargs):
assert len(depth) == len(kernel_size) == len(expand_ratio) == len(width) - 2
#set resolution
self.active_resolution = resolution

# first conv
self.first_conv.active_out_channel = width[0]

for stage_id, (c, k, e, d) in enumerate(zip(width[1:-1], kernel_size, expand_ratio, depth)):
start_idx, end_idx = min(self.block_group_info[stage_id]), max(self.block_group_info[stage_id])
for block_id in range(start_idx, start_idx+d):
block = self.blocks[block_id]
#block output channels
block.mobile_inverted_conv.active_out_channel = c
if block.shortcut is not None:
block.shortcut.active_out_channel = c

#dw kernel size
block.mobile_inverted_conv.active_kernel_size = k

                #dw expansion ratio
block.mobile_inverted_conv.active_expand_ratio = e

        #IR blocks repeated times
for i, d in enumerate(depth):
self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

#last conv
if not self.use_v3_head:
self.last_conv.active_out_channel = width[-1]
else:
# default expansion ratio: 6
self.last_conv.final_expand_layer.active_out_channel = width[-2] * 6
self.last_conv.feature_mix_layer.active_out_channel = width[-1]

def get_active_subnet_settings(self):
r = self.active_resolution
width, depth, kernel_size, expand_ratio= [], [], [], []

#first conv
width.append(self.first_conv.active_out_channel)
for stage_id in range(len(self.block_group_info)):
start_idx = min(self.block_group_info[stage_id])
block = self.blocks[start_idx] #first block
width.append(block.mobile_inverted_conv.active_out_channel)
kernel_size.append(block.mobile_inverted_conv.active_kernel_size)
expand_ratio.append(block.mobile_inverted_conv.active_expand_ratio)
depth.append(self.runtime_depth[stage_id])
if not self.use_v3_head:
width.append(self.last_conv.active_out_channel)
else:
width.append(self.last_conv.feature_mix_layer.active_out_channel)

return {
'resolution': r,
'width': width,
'kernel_size': kernel_size,
'expand_ratio': expand_ratio,
'depth': depth,
}

def set_dropout_rate(self, dropout=0, drop_connect=0, drop_connect_only_last_two_stages=True):
self.active_dropout_rate = dropout
for idx, block in enumerate(self.blocks):
if drop_connect_only_last_two_stages:
if idx not in self.block_group_info[-1] + self.block_group_info[-2]:
continue
this_drop_connect_rate = drop_connect * float(idx) / len(self.blocks)
block.drop_connect_rate = this_drop_connect_rate


def sample_min_subnet(self):
return self._sample_active_subnet(min_net=True)


def sample_max_subnet(self):
return self._sample_active_subnet(max_net=True)

def sample_active_subnet(self, compute_flops=False):
cfg = self._sample_active_subnet(
False, False
)
if compute_flops:
cfg['flops'] = self.compute_active_subnet_flops()
return cfg

def sample_active_subnet_within_range(self, targeted_min_flops, targeted_max_flops):
while True:
cfg = self._sample_active_subnet()
cfg['flops'] = self.compute_active_subnet_flops()
if cfg['flops'] >= targeted_min_flops and cfg['flops'] <= targeted_max_flops:
return cfg

def _sample_active_subnet(self, min_net=False, max_net=False):

sample_cfg = lambda candidates, sample_min, sample_max: \
min(candidates) if sample_min else (max(candidates) if sample_max else random.choice(candidates))

cfg = {}
# sample a resolution
cfg['resolution'] = sample_cfg(self.cfg_candidates['resolution'], min_net, max_net)
for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
cfg[k] = []
for vv in self.cfg_candidates[k]:
cfg[k].append(sample_cfg(val2list(vv), min_net, max_net))

self.set_active_subnet(
cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio']
)
return cfg


def mutate_and_reset(self, cfg, prob=0.1, keep_resolution=False):
cfg = deepcopy(cfg)
pick_another = lambda x, candidates: x if len(candidates) == 1 else random.choice([v for v in candidates if v != x])
# sample a resolution
r = random.random()
if r < prob and not keep_resolution:
cfg['resolution'] = pick_another(cfg['resolution'], self.cfg_candidates['resolution'])

# sample channels, depth, kernel_size, expand_ratio
for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
for _i, _v in enumerate(cfg[k]):
r = random.random()
if r < prob:
cfg[k][_i] = pick_another(cfg[k][_i], val2list(self.cfg_candidates[k][_i]))

self.set_active_subnet(
cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio']
)
return cfg


def crossover_and_reset(self, cfg1, cfg2, p=0.5):
def _cross_helper(g1, g2, prob):
assert type(g1) == type(g2)
if isinstance(g1, int):
return g1 if random.random() < prob else g2
elif isinstance(g1, list):
return [v1 if random.random() < prob else v2 for v1, v2 in zip(g1, g2)]
else:
raise NotImplementedError

cfg = {}
cfg['resolution'] = cfg1['resolution'] if random.random() < p else cfg2['resolution']
for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
cfg[k] = _cross_helper(cfg1[k], cfg2[k], p)

self.set_active_subnet(
cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio']
)
return cfg


def get_active_subnet(self, preserve_weight=True):
with torch.no_grad():
first_conv = self.first_conv.get_active_subnet(3, preserve_weight)

blocks = []
input_channel = first_conv.out_channels
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append(ResidualBlock(
self.blocks[idx].mobile_inverted_conv.get_active_subnet(input_channel, preserve_weight),
self.blocks[idx].shortcut.get_active_subnet(input_channel, preserve_weight) if self.blocks[idx].shortcut is not None else None
))
input_channel = stage_blocks[-1].mobile_inverted_conv.out_channels
blocks += stage_blocks

if not self.use_v3_head:
last_conv = self.last_conv.get_active_subnet(input_channel, preserve_weight)
in_features = last_conv.out_channels
else:
final_expand_layer = self.last_conv.final_expand_layer.get_active_subnet(input_channel, preserve_weight)
feature_mix_layer = self.last_conv.feature_mix_layer.get_active_subnet(input_channel*6, preserve_weight)
in_features = feature_mix_layer.out_channels
last_conv = nn.Sequential(
final_expand_layer,
nn.AdaptiveAvgPool2d((1,1)),
feature_mix_layer
)

classifier = self.classifier.get_active_subnet(in_features, preserve_weight)

_subnet = AttentiveNasStaticModel(
first_conv, blocks, last_conv, classifier, self.active_resolution, use_v3_head=self.use_v3_head
)
_subnet.set_bn_param(**self.get_bn_param())
return _subnet


def compute_active_subnet_flops(self):

def count_conv(c_in, c_out, size_out, groups, k):
kernel_ops = k**2
output_elements = c_out * size_out**2
ops = c_in * output_elements * kernel_ops / groups
return ops

def count_linear(c_in, c_out):
return c_in * c_out

total_ops = 0

c_in = 3
size_out = self.active_resolution // self.first_conv.stride
c_out = self.first_conv.active_out_channel

total_ops += count_conv(c_in, c_out, size_out, 1, 3)
c_in = c_out

# mb blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
block = self.blocks[idx]
c_middle = make_divisible(round(c_in * block.mobile_inverted_conv.active_expand_ratio), 8)
# 1*1 conv
if block.mobile_inverted_conv.inverted_bottleneck is not None:
total_ops += count_conv(c_in, c_middle, size_out, 1, 1)
# dw conv
stride = 1 if idx > active_idx[0] else block.mobile_inverted_conv.stride
if size_out % stride == 0:
size_out = size_out // stride
else:
size_out = (size_out +1) // stride
total_ops += count_conv(c_middle, c_middle, size_out, c_middle, block.mobile_inverted_conv.active_kernel_size)
# 1*1 conv
c_out = block.mobile_inverted_conv.active_out_channel
total_ops += count_conv(c_middle, c_out, size_out, 1, 1)
#se
if block.mobile_inverted_conv.use_se:
num_mid = make_divisible(c_middle // block.mobile_inverted_conv.depth_conv.se.reduction, divisor=8)
total_ops += count_conv(c_middle, num_mid, 1, 1, 1) * 2
if block.shortcut and c_in != c_out:
total_ops += count_conv(c_in, c_out, size_out, 1, 1)
c_in = c_out

if not self.use_v3_head:
c_out = self.last_conv.active_out_channel
total_ops += count_conv(c_in, c_out, size_out, 1, 1)
else:
c_expand = self.last_conv.final_expand_layer.active_out_channel
c_out = self.last_conv.feature_mix_layer.active_out_channel
total_ops += count_conv(c_in, c_expand, size_out, 1, 1)
total_ops += count_conv(c_expand, c_out, 1, 1, 1)

# n_classes
total_ops += count_linear(c_out, self.n_classes)
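        # Worked example for count_conv above (illustrative numbers only): a stride-2 3x3 stem
        # mapping 3 -> 16 channels on a 192x192 input costs 3 * (16 * 96**2) * 3**2 = 3,981,312
        # ops, i.e. roughly 3.98 of the MFLOPs returned below.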
return total_ops / 1e6


def load_weights_from_pretrained_models(self, checkpoint_path):
with open(checkpoint_path, 'rb') as f:
checkpoint = torch.load(f, map_location='cpu')
assert isinstance(checkpoint, dict)
pretrained_state_dicts = checkpoint['state_dict']
for k, v in self.state_dict().items():
name = 'module.' + k if not k.startswith('module') else k
v.copy_(pretrained_state_dicts[name])


def _AttentiveNAS_CNN():
from xnas.core.config import cfg
bn_momentum = cfg.ATTENTIVENAS.BN_MOMENTUM
bn_eps = cfg.ATTENTIVENAS.BN_EPS
return AttentiveNasDynamicModel(
cfg.ATTENTIVENAS.SUPERNET_CFG,
cfg.LOADER.NUM_CLASSES,
(bn_momentum, bn_eps),
)

def _infer_AttentiveNAS_CNN():
from xnas.core.config import cfg
bn_momentum = cfg.ATTENTIVENAS.BN_MOMENTUM
bn_eps = cfg.ATTENTIVENAS.BN_EPS
supernet = AttentiveNasDynamicModel(
cfg.ATTENTIVENAS.SUPERNET_CFG,
cfg.LOADER.NUM_CLASSES,
(bn_momentum, bn_eps),
)
# namespace changed: pareto_models.supernet_checkpoint_path
supernet.load_weights_from_pretrained_models(cfg.SEARCH.WEIGHT)
# namespace created: active_subnet.*
supernet.set_active_subnet(
resolution=cfg.ATTENTIVENAS.ACTIVE_SUBNET.RESOLUTION,
width = cfg.ATTENTIVENAS.ACTIVE_SUBNET.WIDTH,
depth = cfg.ATTENTIVENAS.ACTIVE_SUBNET.DEPTH,
kernel_size = cfg.ATTENTIVENAS.ACTIVE_SUBNET.KERNEL_SIZE,
expand_ratio = cfg.ATTENTIVENAS.ACTIVE_SUBNET.EXPAND_RATIO,
)
model = supernet.get_active_subnet()
    # house-keeping: the subnet's BN parameters may use different values from the supernet
model.set_bn_param(momentum=bn_momentum, eps=bn_eps)
del supernet
return model
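
For orientation, a minimal usage sketch of the dynamic-model API defined in this file (not part of the diff; _AttentiveNAS_CNN() assumes an AttentiveNAS config has already been loaded via xnas.core.config, and the FLOPs bounds below are illustrative):

import torch

supernet = _AttentiveNAS_CNN()

# draw a random subnet and record its cost in MFLOPs
subnet_cfg = supernet.sample_active_subnet(compute_flops=True)

# or keep re-sampling until the cost falls inside a target window
subnet_cfg = supernet.sample_active_subnet_within_range(
    targeted_min_flops=200, targeted_max_flops=400)

# export the currently active subnet as a standalone static model
static_net = supernet.get_active_subnet(preserve_weight=True)
static_net.eval()
logits = static_net(torch.randn(2, 3, 224, 224))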

+ 653
- 0
xnas/spaces/BigNAS/cnn.py View File

@@ -0,0 +1,653 @@
# Implementation adapted from AttentiveNAS: https://github.com/facebookresearch/AttentiveNAS

import random
from copy import deepcopy
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from xnas.spaces.OFA.ops import ResidualBlock
from xnas.spaces.OFA.dynamic_ops import DynamicLinearLayer
from xnas.spaces.OFA.utils import val2list, make_divisible
from xnas.spaces.BigNAS.dynamic_layers import DynamicMBConvLayer, DynamicConvLayer, DynamicShortcutLayer


class BigNASStaticModel(nn.Module):

def __init__(self, first_conv, blocks, last_conv, classifier, resolution, use_v3_head=True):
super(BigNASStaticModel, self).__init__()
self.first_conv = first_conv
self.blocks = nn.ModuleList(blocks)
self.last_conv = last_conv
self.classifier = classifier

self.resolution = resolution #input size
self.use_v3_head = use_v3_head

def forward(self, x):
# resize input to target resolution first
# Rule: transform images into different sizes
if x.size(-1) != self.resolution:
x = F.interpolate(x, size=self.resolution, mode='bicubic')

x = self.first_conv(x)
for block in self.blocks:
x = block(x)
x = self.last_conv(x)
if not self.use_v3_head:
x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
x = torch.squeeze(x)
x = self.classifier(x)
return x


@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
for block in self.blocks:
_str += block.module_str + '\n'
#_str += self.last_conv.module_str + '\n'
_str += self.classifier.module_str
return _str

@property
def config(self):
return {
'name': BigNASStaticModel.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
#'last_conv': self.last_conv.config,
'classifier': self.classifier.config,
'resolution': self.resolution
}


def weight_initialization(self):
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)

@staticmethod
def build_from_config(config):
raise NotImplementedError

def set_bn_param(self, momentum, eps):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
if momentum is not None:
m.momentum = float(momentum)
else:
m.momentum = None
m.eps = float(eps)
return

def get_bn_param(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
return {
'momentum': m.momentum,
'eps': m.eps,
}
return None

def reset_running_stats_for_calibration(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
m.training = True
m.momentum = None # cumulative moving average
m.reset_running_stats()


class BigNASDynamicModel(nn.Module):

def __init__(self, supernet_cfg, n_classes=1000, bn_param=(0., 1e-5)):
super(BigNASDynamicModel, self).__init__()

self.supernet_cfg = supernet_cfg
self.n_classes = n_classes
self.use_v3_head = getattr(self.supernet_cfg, 'use_v3_head', False)
self.stage_names = ['first_conv', 'mb1', 'mb2', 'mb3', 'mb4', 'mb5', 'mb6', 'mb7', 'last_conv']

self.width_list, self.depth_list, self.ks_list, self.expand_ratio_list = [], [], [], []
for name in self.stage_names:
block_cfg = getattr(self.supernet_cfg, name)
self.width_list.append(block_cfg.c)
if name.startswith('mb'):
self.depth_list.append(block_cfg.d)
self.ks_list.append(block_cfg.k)
self.expand_ratio_list.append(block_cfg.t)
self.resolution_list = self.supernet_cfg.resolutions

self.cfg_candidates = {
'resolution': self.resolution_list ,
'width': self.width_list,
'depth': self.depth_list,
'kernel_size': self.ks_list,
'expand_ratio': self.expand_ratio_list
}

#first conv layer, including conv, bn, act
out_channel_list, act_func, stride = \
self.supernet_cfg.first_conv.c, self.supernet_cfg.first_conv.act_func, self.supernet_cfg.first_conv.s
self.first_conv = DynamicConvLayer(
in_channel_list=val2list(3), out_channel_list=out_channel_list,
kernel_size=3, stride=stride, act_func=act_func,
)

# inverted residual blocks
self.block_group_info = []
blocks = []
_block_index = 0
feature_dim = out_channel_list
for stage_id, key in enumerate(self.stage_names[1:-1]):
block_cfg = getattr(self.supernet_cfg, key)
width = block_cfg.c
n_block = max(block_cfg.d)
act_func = block_cfg.act_func
ks = block_cfg.k
expand_ratio_list = block_cfg.t
use_se = block_cfg.se

self.block_group_info.append([_block_index + i for i in range(n_block)])
_block_index += n_block

output_channel = width
for i in range(n_block):
stride = block_cfg.s if i == 0 else 1
if min(expand_ratio_list) >= 4:
expand_ratio_list = [_s for _s in expand_ratio_list if _s >= 4] if i == 0 else expand_ratio_list
mobile_inverted_conv = DynamicMBConvLayer(
in_channel_list=feature_dim,
out_channel_list=output_channel,
kernel_size_list=ks,
expand_ratio_list=expand_ratio_list,
stride=stride,
act_func=act_func,
use_se=use_se,
channels_per_group=getattr(self.supernet_cfg, 'channels_per_group', 1)
)
# Rule: add skip-connect, and use 2x2 AvgPool or 1x1 Conv for adaptation
shortcut = DynamicShortcutLayer(feature_dim, output_channel, reduction=stride)
blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
feature_dim = output_channel
self.blocks = nn.ModuleList(blocks)

last_channel, act_func = self.supernet_cfg.last_conv.c, self.supernet_cfg.last_conv.act_func
if not self.use_v3_head:
self.last_conv = DynamicConvLayer(
in_channel_list=feature_dim, out_channel_list=last_channel,
kernel_size=1, act_func=act_func,
)
else:
expand_feature_dim = [f_dim * 6 for f_dim in feature_dim]
self.last_conv = nn.Sequential(OrderedDict([
('final_expand_layer', DynamicConvLayer(
feature_dim, expand_feature_dim, kernel_size=1, use_bn=True, act_func=act_func)
),
('pool', nn.AdaptiveAvgPool2d((1,1))),
('feature_mix_layer', DynamicConvLayer(
in_channel_list=expand_feature_dim, out_channel_list=last_channel,
kernel_size=1, act_func=act_func, use_bn=False,)
),
]))

        # classifier
self.classifier = DynamicLinearLayer(
in_features_list=last_channel, out_features=n_classes, bias=True
)

# set bn param
self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

# runtime_depth
self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

self.zero_residual_block_bn_weights()

self.active_dropout_rate = 0
self.active_drop_connect_rate = 0
self.active_resolution = 224

# Rule: Initialize learnable coefficient \gamma=0
def zero_residual_block_bn_weights(self):
with torch.no_grad():
for m in self.modules():
if isinstance(m, ResidualBlock):
if isinstance(m.mobile_inverted_conv, DynamicMBConvLayer) and m.shortcut is not None:
m.mobile_inverted_conv.point_linear.bn.bn.weight.zero_()

@staticmethod
def name():
return 'BigNASDynamicModel'

def forward(self, x):
# resize input to target resolution first
if x.size(-1) != self.active_resolution:
x = F.interpolate(x, size=self.active_resolution, mode='bicubic')

# first conv
x = self.first_conv(x)
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
x = self.blocks[idx](x)

x = self.last_conv(x)
x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
x = torch.squeeze(x)

if self.active_dropout_rate > 0 and self.training:
x = F.dropout(x, p = self.active_dropout_rate)

x = self.classifier(x)
return x


@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
_str += self.blocks[0].module_str + '\n'

for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
_str += self.blocks[idx].module_str + '\n'
if not self.use_v3_head:
_str += self.last_conv.module_str + '\n'
else:
_str += self.last_conv.final_expand_layer.module_str + '\n'
_str += self.last_conv.feature_mix_layer.module_str + '\n'
_str += self.classifier.module_str + '\n'
return _str

@property
def config(self):
return {
'name': BigNASDynamicModel.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
'last_conv': self.last_conv.config if not self.use_v3_head else None,
'final_expand_layer': self.last_conv.final_expand_layer if self.use_v3_head else None,
'feature_mix_layer': self.last_conv.feature_mix_layer if self.use_v3_head else None,
'classifier': self.classifier.config,
'resolution': self.active_resolution
}


@staticmethod
def build_from_config(config):
raise NotImplementedError

def get_parameters(self, keys=None, mode="include"):
if keys is None:
for name, param in self.named_parameters():
if param.requires_grad:
yield param
elif mode == "include":
for name, param in self.named_parameters():
flag = False
for key in keys:
if key in name:
flag = True
break
if flag and param.requires_grad:
yield param
elif mode == "exclude":
for name, param in self.named_parameters():
flag = True
for key in keys:
if key in name:
flag = False
break
if flag and param.requires_grad:
yield param
else:
raise ValueError("do not support: %s" % mode)

def set_bn_param(self, momentum, eps):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
if momentum is not None:
m.momentum = float(momentum)
else:
m.momentum = None
m.eps = float(eps)
return

def get_bn_param(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.SyncBatchNorm):
return {
'momentum': m.momentum,
'eps': m.eps,
}
return None

""" set, sample and get active sub-networks """
def set_active_subnet(self, resolution=224, width=None, depth=None, kernel_size=None, expand_ratio=None, **kwargs):
assert len(depth) == len(kernel_size) == len(expand_ratio) == len(width) - 2
#set resolution
self.active_resolution = resolution

# first conv
self.first_conv.active_out_channel = width[0]

for stage_id, (c, k, e, d) in enumerate(zip(width[1:-1], kernel_size, expand_ratio, depth)):
start_idx, end_idx = min(self.block_group_info[stage_id]), max(self.block_group_info[stage_id])
for block_id in range(start_idx, start_idx+d):
block = self.blocks[block_id]
#block output channels
block.mobile_inverted_conv.active_out_channel = c
if block.shortcut is not None:
block.shortcut.active_out_channel = c

#dw kernel size
block.mobile_inverted_conv.active_kernel_size = k

                #dw expansion ratio
block.mobile_inverted_conv.active_expand_ratio = e

        #IR blocks repeated times
for i, d in enumerate(depth):
self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

#last conv
if not self.use_v3_head:
self.last_conv.active_out_channel = width[-1]
else:
# default expansion ratio: 6
self.last_conv.final_expand_layer.active_out_channel = width[-2] * 6
self.last_conv.feature_mix_layer.active_out_channel = width[-1]

def get_active_subnet_settings(self):
r = self.active_resolution
width, depth, kernel_size, expand_ratio= [], [], [], []

#first conv
width.append(self.first_conv.active_out_channel)
for stage_id in range(len(self.block_group_info)):
start_idx = min(self.block_group_info[stage_id])
block = self.blocks[start_idx] #first block
width.append(block.mobile_inverted_conv.active_out_channel)
kernel_size.append(block.mobile_inverted_conv.active_kernel_size)
expand_ratio.append(block.mobile_inverted_conv.active_expand_ratio)
depth.append(self.runtime_depth[stage_id])
if not self.use_v3_head:
width.append(self.last_conv.active_out_channel)
else:
width.append(self.last_conv.feature_mix_layer.active_out_channel)

return {
'resolution': r,
'width': width,
'kernel_size': kernel_size,
'expand_ratio': expand_ratio,
'depth': depth,
}

def set_dropout_rate(self, dropout=0, drop_connect=0, drop_connect_only_last_two_stages=True):
self.active_dropout_rate = dropout
for idx, block in enumerate(self.blocks):
if drop_connect_only_last_two_stages:
if idx not in self.block_group_info[-1] + self.block_group_info[-2]:
continue
this_drop_connect_rate = drop_connect * float(idx) / len(self.blocks)
block.drop_connect_rate = this_drop_connect_rate


def sample_min_subnet(self):
return self._sample_active_subnet(min_net=True)


def sample_max_subnet(self):
return self._sample_active_subnet(max_net=True)

def sample_active_subnet(self, compute_flops=False):
cfg = self._sample_active_subnet(
False, False
)
if compute_flops:
cfg['flops'] = self.compute_active_subnet_flops()
return cfg

def sample_active_subnet_within_range(self, targeted_min_flops, targeted_max_flops):
while True:
cfg = self._sample_active_subnet()
cfg['flops'] = self.compute_active_subnet_flops()
if cfg['flops'] >= targeted_min_flops and cfg['flops'] <= targeted_max_flops:
return cfg

def _sample_active_subnet(self, min_net=False, max_net=False):

sample_cfg = lambda candidates, sample_min, sample_max: \
min(candidates) if sample_min else (max(candidates) if sample_max else random.choice(candidates))

cfg = {}
# sample a resolution
cfg['resolution'] = sample_cfg(self.cfg_candidates['resolution'], min_net, max_net)
for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
cfg[k] = []
for vv in self.cfg_candidates[k]:
cfg[k].append(sample_cfg(val2list(vv), min_net, max_net))

self.set_active_subnet(
cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio']
)
return cfg


def mutate_and_reset(self, cfg, prob=0.1, keep_resolution=False):
cfg = deepcopy(cfg)
pick_another = lambda x, candidates: x if len(candidates) == 1 else random.choice([v for v in candidates if v != x])
# sample a resolution
r = random.random()
if r < prob and not keep_resolution:
cfg['resolution'] = pick_another(cfg['resolution'], self.cfg_candidates['resolution'])

# sample channels, depth, kernel_size, expand_ratio
for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
for _i, _v in enumerate(cfg[k]):
r = random.random()
if r < prob:
cfg[k][_i] = pick_another(cfg[k][_i], val2list(self.cfg_candidates[k][_i]))

self.set_active_subnet(
cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio']
)
return cfg


def crossover_and_reset(self, cfg1, cfg2, p=0.5):
def _cross_helper(g1, g2, prob):
assert type(g1) == type(g2)
if isinstance(g1, int):
return g1 if random.random() < prob else g2
elif isinstance(g1, list):
return [v1 if random.random() < prob else v2 for v1, v2 in zip(g1, g2)]
else:
raise NotImplementedError

cfg = {}
cfg['resolution'] = cfg1['resolution'] if random.random() < p else cfg2['resolution']
for k in ['width', 'depth', 'kernel_size', 'expand_ratio']:
cfg[k] = _cross_helper(cfg1[k], cfg2[k], p)

self.set_active_subnet(
cfg['resolution'], cfg['width'], cfg['depth'], cfg['kernel_size'], cfg['expand_ratio']
)
return cfg


def get_active_subnet(self, preserve_weight=True):
with torch.no_grad():
first_conv = self.first_conv.get_active_subnet(3, preserve_weight)

blocks = []
input_channel = first_conv.out_channels
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append(ResidualBlock(
self.blocks[idx].mobile_inverted_conv.get_active_subnet(input_channel, preserve_weight),
self.blocks[idx].shortcut.get_active_subnet(input_channel, preserve_weight) if self.blocks[idx].shortcut is not None else None
))
input_channel = stage_blocks[-1].mobile_inverted_conv.out_channels
blocks += stage_blocks

if not self.use_v3_head:
last_conv = self.last_conv.get_active_subnet(input_channel, preserve_weight)
in_features = last_conv.out_channels
else:
final_expand_layer = self.last_conv.final_expand_layer.get_active_subnet(input_channel, preserve_weight)
feature_mix_layer = self.last_conv.feature_mix_layer.get_active_subnet(input_channel*6, preserve_weight)
in_features = feature_mix_layer.out_channels
last_conv = nn.Sequential(
final_expand_layer,
nn.AdaptiveAvgPool2d((1,1)),
feature_mix_layer
)

classifier = self.classifier.get_active_subnet(in_features, preserve_weight)

_subnet = BigNASStaticModel(
first_conv, blocks, last_conv, classifier, self.active_resolution, use_v3_head=self.use_v3_head
)
_subnet.set_bn_param(**self.get_bn_param())
return _subnet


def compute_active_subnet_flops(self):

def count_conv(c_in, c_out, size_out, groups, k):
kernel_ops = k**2
output_elements = c_out * size_out**2
ops = c_in * output_elements * kernel_ops / groups
return ops

def count_linear(c_in, c_out):
return c_in * c_out

total_ops = 0

c_in = 3
size_out = self.active_resolution // self.first_conv.stride
c_out = self.first_conv.active_out_channel

total_ops += count_conv(c_in, c_out, size_out, 1, 3)
c_in = c_out

# mb blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
block = self.blocks[idx]
c_middle = make_divisible(round(c_in * block.mobile_inverted_conv.active_expand_ratio), 8)
# 1*1 conv
if block.mobile_inverted_conv.inverted_bottleneck is not None:
total_ops += count_conv(c_in, c_middle, size_out, 1, 1)
# dw conv
stride = 1 if idx > active_idx[0] else block.mobile_inverted_conv.stride
if size_out % stride == 0:
size_out = size_out // stride
else:
size_out = (size_out +1) // stride
total_ops += count_conv(c_middle, c_middle, size_out, c_middle, block.mobile_inverted_conv.active_kernel_size)
# 1*1 conv
c_out = block.mobile_inverted_conv.active_out_channel
total_ops += count_conv(c_middle, c_out, size_out, 1, 1)
#se
if block.mobile_inverted_conv.use_se:
num_mid = make_divisible(c_middle // block.mobile_inverted_conv.depth_conv.se.reduction, divisor=8)
total_ops += count_conv(c_middle, num_mid, 1, 1, 1) * 2
if block.shortcut and c_in != c_out:
total_ops += count_conv(c_in, c_out, size_out, 1, 1)
c_in = c_out

if not self.use_v3_head:
c_out = self.last_conv.active_out_channel
total_ops += count_conv(c_in, c_out, size_out, 1, 1)
else:
c_expand = self.last_conv.final_expand_layer.active_out_channel
c_out = self.last_conv.feature_mix_layer.active_out_channel
total_ops += count_conv(c_in, c_expand, size_out, 1, 1)
total_ops += count_conv(c_expand, c_out, 1, 1, 1)

# n_classes
total_ops += count_linear(c_out, self.n_classes)
return total_ops / 1e6


def load_weights_from_pretrained_models(self, checkpoint_path):
with open(checkpoint_path, 'rb') as f:
checkpoint = torch.load(f, map_location='cpu')
assert isinstance(checkpoint, dict)
pretrained_state_dicts = checkpoint['model_state']
for k, v in self.state_dict().items():
# name = 'module.' + k if not k.startswith('module') else k
name = k
v.copy_(pretrained_state_dicts[name])


def _BigNAS_CNN():
from xnas.core.config import cfg
bn_momentum = cfg.BIGNAS.BN_MOMENTUM
bn_eps = cfg.BIGNAS.BN_EPS
return BigNASDynamicModel(
cfg.BIGNAS.SUPERNET_CFG,
cfg.LOADER.NUM_CLASSES,
(bn_momentum, bn_eps),
)

def _infer_BigNAS_CNN():
from xnas.core.config import cfg
bn_momentum = cfg.BIGNAS.BN_MOMENTUM
bn_eps = cfg.BIGNAS.BN_EPS
supernet = BigNASDynamicModel(
cfg.BIGNAS.SUPERNET_CFG,
cfg.LOADER.NUM_CLASSES,
(bn_momentum, bn_eps),
)
# namespace changed: pareto_models.supernet_checkpoint_path
supernet.load_weights_from_pretrained_models(cfg.SEARCH.WEIGHT)
# namespace created: active_subnet.*
supernet.set_active_subnet(
resolution=cfg.BIGNAS.ACTIVE_SUBNET.RESOLUTION,
width = cfg.BIGNAS.ACTIVE_SUBNET.WIDTH,
depth = cfg.BIGNAS.ACTIVE_SUBNET.DEPTH,
kernel_size = cfg.BIGNAS.ACTIVE_SUBNET.KERNEL_SIZE,
expand_ratio = cfg.BIGNAS.ACTIVE_SUBNET.EXPAND_RATIO,
)
model = supernet.get_active_subnet()
    # house-keeping: the subnet's BN parameters may use different values from the supernet
model.set_bn_param(momentum=bn_momentum, eps=bn_eps)
del supernet
return model
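
A rough sketch of how the min/random/max sampling hooks above are typically combined into a sandwich-rule update (illustrative only: images, labels, criterion and optimizer are placeholders, and the actual logic lives in scripts/search/BigNAS/train_supernet.py, which may differ):

supernet = _BigNAS_CNN()

optimizer.zero_grad()

supernet.sample_max_subnet()                    # largest subnet first
criterion(supernet(images), labels).backward()

for _ in range(2):                              # a few random subnets
    supernet.sample_active_subnet()
    criterion(supernet(images), labels).backward()

supernet.sample_min_subnet()                    # smallest subnet last
criterion(supernet(images), labels).backward()

optimizer.step()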

+ 331
- 0
xnas/spaces/BigNAS/dynamic_layers.py View File

@@ -0,0 +1,331 @@
from collections import OrderedDict

import torch.nn as nn
import torch.nn.functional as F


from xnas.spaces.OFA.utils import val2list
from xnas.spaces.OFA.ops import SEModule, ConvLayer, ShortcutLayer, build_activation, make_divisible
from xnas.spaces.OFA.dynamic_ops import DynamicConv2d, DynamicSE, copy_bn
from xnas.spaces.BigNAS.ops import MBConvLayer
from xnas.spaces.BigNAS.dynamic_ops import DynamicSeparableConv2d, DynamicBatchNorm2d


class DynamicMBConvLayer(nn.Module):
def __init__(self, in_channel_list, out_channel_list,
kernel_size_list=3, expand_ratio_list=6, stride=1, act_func='relu6', use_se=False, channels_per_group=1):
super(DynamicMBConvLayer, self).__init__()
self.in_channel_list = val2list(in_channel_list)
self.out_channel_list = val2list(out_channel_list)
self.kernel_size_list = val2list(kernel_size_list, 1)
self.expand_ratio_list = val2list(expand_ratio_list, 1)
self.stride = stride
self.act_func = act_func
self.use_se = use_se
self.channels_per_group = channels_per_group
# build modules
max_middle_channel = round(max(self.in_channel_list) * max(self.expand_ratio_list))
if max(self.expand_ratio_list) == 1:
self.inverted_bottleneck = None
else:
self.inverted_bottleneck = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max(self.in_channel_list), max_middle_channel)),
('bn', DynamicBatchNorm2d(max_middle_channel)),
('act', build_activation(self.act_func, inplace=True)),
]))
self.depth_conv = nn.Sequential(OrderedDict([
('conv', DynamicSeparableConv2d(max_middle_channel, self.kernel_size_list, stride=self.stride, channels_per_group=self.channels_per_group)),
('bn', DynamicBatchNorm2d(max_middle_channel)),
('act', build_activation(self.act_func, inplace=True))
]))
if self.use_se:
self.depth_conv.add_module('se', DynamicSE(max_middle_channel))
self.point_linear = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max_middle_channel, max(self.out_channel_list))),
('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
]))
self.active_kernel_size = max(self.kernel_size_list)
self.active_expand_ratio = max(self.expand_ratio_list)
self.active_out_channel = max(self.out_channel_list)
def forward(self, x):
in_channel = x.size(1)
if self.inverted_bottleneck is not None:
self.inverted_bottleneck.conv.active_out_channel = \
make_divisible(round(in_channel * self.active_expand_ratio), 8)

self.depth_conv.conv.active_kernel_size = self.active_kernel_size
self.point_linear.conv.active_out_channel = self.active_out_channel
if self.inverted_bottleneck is not None:
x = self.inverted_bottleneck(x)
x = self.depth_conv(x)
x = self.point_linear(x)
return x
@property
def module_str(self):
if self.use_se:
return 'SE(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)
else:
return '(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)
@property
def config(self):
return {
'name': DynamicMBConvLayer.__name__,
'in_channel_list': self.in_channel_list,
'out_channel_list': self.out_channel_list,
'kernel_size_list': self.kernel_size_list,
'expand_ratio_list': self.expand_ratio_list,
'stride': self.stride,
'act_func': self.act_func,
'use_se': self.use_se,
'channels_per_group': self.channels_per_group,
}
@staticmethod
def build_from_config(config):
return DynamicMBConvLayer(**config)

############################################################################################

def get_active_subnet(self, in_channel, preserve_weight=True):
middle_channel = make_divisible(round(in_channel * self.active_expand_ratio), 8)
channels_per_group = self.depth_conv.conv.channels_per_group

# build the new layer
sub_layer = MBConvLayer(
in_channel, self.active_out_channel, self.active_kernel_size, self.stride, self.active_expand_ratio,
act_func=self.act_func, mid_channels=middle_channel, use_se=self.use_se, channels_per_group=channels_per_group
)
sub_layer = sub_layer.to(self.parameters().__next__().device)

if not preserve_weight:
return sub_layer

# copy weight from current layer
if sub_layer.inverted_bottleneck is not None:
sub_layer.inverted_bottleneck.conv.weight.data.copy_(
self.inverted_bottleneck.conv.conv.weight.data[:middle_channel, :in_channel, :, :]
)
copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn)

sub_layer.depth_conv.conv.weight.data.copy_(
self.depth_conv.conv.get_active_filter(middle_channel, self.active_kernel_size).data
)
copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn)

if self.use_se:
se_mid = make_divisible(middle_channel // SEModule.REDUCTION, divisor=8)
sub_layer.depth_conv.se.fc.reduce.weight.data.copy_(
self.depth_conv.se.fc.reduce.weight.data[:se_mid, :middle_channel, :, :]
)
sub_layer.depth_conv.se.fc.reduce.bias.data.copy_(self.depth_conv.se.fc.reduce.bias.data[:se_mid])

sub_layer.depth_conv.se.fc.expand.weight.data.copy_(
self.depth_conv.se.fc.expand.weight.data[:middle_channel, :se_mid, :, :]
)
sub_layer.depth_conv.se.fc.expand.bias.data.copy_(self.depth_conv.se.fc.expand.bias.data[:middle_channel])

sub_layer.point_linear.conv.weight.data.copy_(
self.point_linear.conv.conv.weight.data[:self.active_out_channel, :middle_channel, :, :]
)
copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn)

return sub_layer

def re_organize_middle_weights(self, expand_ratio_stage=0):
# importance = torch.sum(torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3))
# if expand_ratio_stage > 0:
# sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
# sorted_expand_list.sort(reverse=True)
# target_width = sorted_expand_list[expand_ratio_stage]
# target_width = round(max(self.in_channel_list) * target_width)
# importance[target_width:] = torch.arange(0, target_width - importance.size(0), -1)
# sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
# self.point_linear.conv.conv.weight.data = torch.index_select(
# self.point_linear.conv.conv.weight.data, 1, sorted_idx
# )
# adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx)
# self.depth_conv.conv.conv.weight.data = torch.index_select(
# self.depth_conv.conv.conv.weight.data, 0, sorted_idx
# )

# if self.use_se:
# # se expand: output dim 0 reorganize
# se_expand = self.depth_conv.se.fc.expand
# se_expand.weight.data = torch.index_select(se_expand.weight.data, 0, sorted_idx)
# se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx)
# # se reduce: input dim 1 reorganize
# se_reduce = self.depth_conv.se.fc.reduce
# se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 1, sorted_idx)
# # middle weight reorganize
# se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3))
# se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True)

# se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx)
# se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx)
# se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx)
# # TODO if inverted_bottleneck is None, the previous layer should be reorganized accordingly
# if self.inverted_bottleneck is not None:
# adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx)
# self.inverted_bottleneck.conv.conv.weight.data = torch.index_select(
# self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx
# )
# return None
# else:
# return sorted_idx
raise NotImplementedError


class DynamicConvLayer(nn.Module):
def __init__(self, in_channel_list, out_channel_list, kernel_size=3, stride=1, dilation=1,
use_bn=True, act_func='relu6'):
super(DynamicConvLayer, self).__init__()
self.in_channel_list = val2list(in_channel_list)
self.out_channel_list = val2list(out_channel_list)
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.use_bn = use_bn
self.act_func = act_func
self.conv = DynamicConv2d(
max_in_channels=max(self.in_channel_list), max_out_channels=max(self.out_channel_list),
kernel_size=self.kernel_size, stride=self.stride, dilation=self.dilation,
)
if self.use_bn:
self.bn = DynamicBatchNorm2d(max(self.out_channel_list))

if self.act_func is not None:
self.act = build_activation(self.act_func, inplace=True)
self.active_out_channel = max(self.out_channel_list)
def forward(self, x):
self.conv.active_out_channel = self.active_out_channel
x = self.conv(x)
if self.use_bn:
x = self.bn(x)
if self.act_func is not None:
x = self.act(x)
return x
@property
def module_str(self):
return 'DyConv(O%d, K%d, S%d)' % (self.active_out_channel, self.kernel_size, self.stride)
@property
def config(self):
return {
'name': DynamicConvLayer.__name__,
'in_channel_list': self.in_channel_list,
'out_channel_list': self.out_channel_list,
'kernel_size': self.kernel_size,
'stride': self.stride,
'dilation': self.dilation,
'use_bn': self.use_bn,
'act_func': self.act_func,
}
@staticmethod
def build_from_config(config):
return DynamicConvLayer(**config)
def get_active_subnet(self, in_channel, preserve_weight=True):
sub_layer = ConvLayer(
in_channel, self.active_out_channel, self.kernel_size, self.stride, self.dilation,
use_bn=self.use_bn, act_func=self.act_func
)
sub_layer = sub_layer.to(self.parameters().__next__().device)
if not preserve_weight:
return sub_layer
sub_layer.conv.weight.data.copy_(self.conv.conv.weight.data[:self.active_out_channel, :in_channel, :, :])
if self.use_bn:
copy_bn(sub_layer.bn, self.bn.bn)
return sub_layer


class DynamicShortcutLayer(nn.Module):
def __init__(self, in_channel_list, out_channel_list, reduction=1):
super(DynamicShortcutLayer, self).__init__()
self.in_channel_list = val2list(in_channel_list)
self.out_channel_list = val2list(out_channel_list)
self.reduction = reduction
self.conv = DynamicConv2d(
max_in_channels=max(self.in_channel_list), max_out_channels=max(self.out_channel_list),
kernel_size=1, stride=1,
)

self.active_out_channel = max(self.out_channel_list)
def forward(self, x):
in_channel = x.size(1)

#identity mapping
if in_channel == self.active_out_channel and self.reduction == 1:
return x
#average pooling, if size doesn't match
if self.reduction > 1:
padding = 0 if x.size(-1) % 2 == 0 else 1
x = F.avg_pool2d(x, self.reduction, padding=padding)

#1*1 conv, if #channels doesn't match
if in_channel != self.active_out_channel:
self.conv.active_out_channel = self.active_out_channel
x = self.conv(x)
return x
@property
def module_str(self):
return 'DyShortcut(O%d, R%d)' % (self.active_out_channel, self.reduction)
@property
def config(self):
return {
'name': DynamicShortcutLayer.__name__,
'in_channel_list': self.in_channel_list,
'out_channel_list': self.out_channel_list,
'reduction': self.reduction,
}
@staticmethod
def build_from_config(config):
return DynamicShortcutLayer(**config)
def get_active_subnet(self, in_channel, preserve_weight=True):
sub_layer = ShortcutLayer(
in_channel, self.active_out_channel, self.reduction
)
sub_layer = sub_layer.to(self.parameters().__next__().device)
if not preserve_weight:
return sub_layer
sub_layer.conv.weight.data.copy_(self.conv.conv.weight.data[:self.active_out_channel, :in_channel, :, :])
return sub_layer
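
A small usage sketch of DynamicMBConvLayer (illustrative values; in the supernet these fields are normally driven through set_active_subnet):

import torch

layer = DynamicMBConvLayer(
    in_channel_list=[16, 24], out_channel_list=[24, 32],
    kernel_size_list=[3, 5], expand_ratio_list=[4, 6],
    stride=2, act_func='relu6',
)
layer.active_kernel_size = 5
layer.active_expand_ratio = 4
layer.active_out_channel = 24

y = layer(torch.randn(2, 24, 56, 56))        # -> torch.Size([2, 24, 28, 28])

# freeze the current choices into a static MBConvLayer, copying the sliced weights
static_layer = layer.get_active_subnet(in_channel=24, preserve_weight=True)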


+ 181
- 0
xnas/spaces/BigNAS/dynamic_ops.py View File

@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.autograd.function import Function

from xnas.spaces.OFA.ops import get_same_padding
from xnas.spaces.OFA.dynamic_ops import sub_filter_start_end



class DynamicSeparableConv2d(nn.Module):
# KERNEL_TRANSFORM_MODE = None # None or 1
def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1, channels_per_group=1):
super(DynamicSeparableConv2d, self).__init__()
self.max_in_channels = max_in_channels
self.channels_per_group = channels_per_group
assert self.max_in_channels % self.channels_per_group == 0
self.kernel_size_list = kernel_size_list
self.stride = stride
self.dilation = dilation
self.conv = nn.Conv2d(
self.max_in_channels, self.max_in_channels, max(self.kernel_size_list), self.stride,
groups=self.max_in_channels // self.channels_per_group, bias=False,
)
self._ks_set = list(set(self.kernel_size_list))
self._ks_set.sort() # e.g., [3, 5, 7]
# if self.KERNEL_TRANSFORM_MODE is not None:
# # register scaling parameters
# # 7to5_matrix, 5to3_matrix
# scale_params = {}
# for i in range(len(self._ks_set) - 1):
# ks_small = self._ks_set[i]
# ks_larger = self._ks_set[i + 1]
# param_name = '%dto%d' % (ks_larger, ks_small)
# scale_params['%s_matrix' % param_name] = Parameter(torch.eye(ks_small ** 2))
# for name, param in scale_params.items():
# self.register_parameter(name, param)

self.active_kernel_size = max(self.kernel_size_list)
def get_active_filter(self, in_channel, kernel_size):
out_channel = in_channel
max_kernel_size = max(self.kernel_size_list)
start, end = sub_filter_start_end(max_kernel_size, kernel_size)
filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
# if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
# start_filter = self.conv.weight[:out_channel, :in_channel, :, :] # start with max kernel
# for i in range(len(self._ks_set) - 1, 0, -1):
# src_ks = self._ks_set[i]
# if src_ks <= kernel_size:
# break
# target_ks = self._ks_set[i - 1]
# start, end = sub_filter_start_end(src_ks, target_ks)
# _input_filter = start_filter[:, :, start:end, start:end]
# _input_filter = _input_filter.contiguous()
# _input_filter = _input_filter.view(_input_filter.size(0), _input_filter.size(1), -1)
# _input_filter = _input_filter.view(-1, _input_filter.size(2))
# _input_filter = F.linear(
# _input_filter, self.__getattr__('%dto%d_matrix' % (src_ks, target_ks)),
# )
# _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks ** 2)
# _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks, target_ks)
# start_filter = _input_filter
# filters = start_filter
return filters
def forward(self, x, kernel_size=None):
if kernel_size is None:
kernel_size = self.active_kernel_size
in_channel = x.size(1)
assert in_channel % self.channels_per_group == 0
filters = self.get_active_filter(in_channel, kernel_size).contiguous()
padding = get_same_padding(kernel_size)
y = F.conv2d(
x, filters, None, self.stride, padding, self.dilation, in_channel // self.channels_per_group
)
return y


class AllReduce(Function):
@staticmethod
def forward(ctx, input):
input_list = [torch.zeros_like(input) for k in range(dist.get_world_size())]
# Use allgather instead of allreduce since I don't trust in-place operations ..
dist.all_gather(input_list, input, async_op=False)
inputs = torch.stack(input_list, dim=0)
return torch.sum(inputs, dim=0)

@staticmethod
def backward(ctx, grad_output):
dist.all_reduce(grad_output, async_op=False)
return grad_output


class DynamicBatchNorm2d(nn.Module):
'''
    1. does not accumulate BN statistics (momentum=0.)
    2. calculates BN statistics of all subnets after training
    3. BN weights are shared
https://arxiv.org/abs/1903.05134
https://detectron2.readthedocs.io/_modules/detectron2/layers/batch_norm.html
'''
#SET_RUNNING_STATISTICS = False
def __init__(self, max_feature_dim):
super(DynamicBatchNorm2d, self).__init__()
self.max_feature_dim = max_feature_dim
self.bn = nn.BatchNorm2d(self.max_feature_dim)

        # self.exponential_average_factor = 0  # does not accumulate bn stats
        self.need_sync = False  # sync batch-normalization, suggested for use in BigNAS

        # reserved for tracking the performance of the largest and smallest networks
self.bn_tracking = nn.ModuleList(
[
nn.BatchNorm2d(self.max_feature_dim, affine=False),
nn.BatchNorm2d(self.max_feature_dim, affine=False)
]
)

def forward(self, x):
feature_dim = x.size(1)
if not self.training:
raise ValueError('DynamicBN only supports training')
bn = self.bn
# need_sync
if not self.need_sync:
return F.batch_norm(
x, bn.running_mean[:feature_dim], bn.running_var[:feature_dim], bn.weight[:feature_dim],
bn.bias[:feature_dim], bn.training or not bn.track_running_stats,
bn.momentum, bn.eps,
)
else:
assert dist.get_world_size() > 1, 'SyncBatchNorm requires >1 world size'
B, C = x.shape[0], x.shape[1]
mean = torch.mean(x, dim=[0, 2, 3])
meansqr = torch.mean(x * x, dim=[0, 2, 3])
assert B > 0, 'does not support zero batch size'
vec = torch.cat([mean, meansqr], dim=0)
vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
mean, meansqr = torch.split(vec, C)

var = meansqr - mean * mean
invstd = torch.rsqrt(var + bn.eps)
scale = bn.weight[:feature_dim] * invstd
bias = bn.bias[:feature_dim] - mean * scale
scale = scale.reshape(1, -1, 1, 1)
bias = bias.reshape(1, -1, 1, 1)
return x * scale + bias


#if bn.num_features == feature_dim or DynamicBatchNorm2d.SET_RUNNING_STATISTICS:
# return bn(x)
#else:
# exponential_average_factor = 0.0

# if bn.training and bn.track_running_stats:
# # TODO: if statement only here to tell the jit to skip emitting this when it is None
# if bn.num_batches_tracked is not None:
# bn.num_batches_tracked += 1
# if bn.momentum is None: # use cumulative moving average
# exponential_average_factor = 1.0 / float(bn.num_batches_tracked)
# else: # use exponential moving average
# exponential_average_factor = bn.momentum
# return F.batch_norm(
# x, bn.running_mean[:feature_dim], bn.running_var[:feature_dim], bn.weight[:feature_dim],
# bn.bias[:feature_dim], bn.training or not bn.track_running_stats,
# exponential_average_factor, bn.eps,
# )
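
To make the kernel-slicing behaviour of DynamicSeparableConv2d concrete, a short sketch (illustrative values):

import torch

conv = DynamicSeparableConv2d(max_in_channels=40, kernel_size_list=[3, 5, 7], stride=1)
conv.active_kernel_size = 5                  # use the centered 5x5 slice of the stored 7x7 filters

y = conv(torch.randn(2, 40, 32, 32))         # depthwise conv, 'same' padding -> (2, 40, 32, 32)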


+ 133
- 0
xnas/spaces/BigNAS/ops.py View File

@@ -0,0 +1,133 @@
import torch
import torch.nn as nn

from collections import OrderedDict
from xnas.spaces.OFA.ops import SEModule, build_activation
from xnas.spaces.OFA.utils import (
get_same_padding,
)

class MBConvLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
expand_ratio=6,
mid_channels=None,
act_func="relu6",
use_se=False,
channels_per_group=1,
):
super(MBConvLayer, self).__init__()

self.in_channels = in_channels
self.out_channels = out_channels

self.kernel_size = kernel_size
self.stride = stride
self.expand_ratio = expand_ratio
self.mid_channels = mid_channels
self.act_func = act_func
self.use_se = use_se
self.channels_per_group = channels_per_group

if self.mid_channels is None:
feature_dim = round(self.in_channels * self.expand_ratio)
else:
feature_dim = self.mid_channels

if self.expand_ratio == 1:
self.inverted_bottleneck = None
else:
self.inverted_bottleneck = nn.Sequential(OrderedDict([
("conv", nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]))

assert feature_dim % self.channels_per_group == 0
active_groups = feature_dim // self.channels_per_group
pad = get_same_padding(self.kernel_size)
# assert feature_dim % self.groups == 0
# active_groups = feature_dim // self.groups
depth_conv_modules = [
(
"conv",
nn.Conv2d(
feature_dim,
feature_dim,
kernel_size,
stride,
pad,
groups=active_groups,
bias=False,
),
),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]
if self.use_se:
depth_conv_modules.append(("se", SEModule(feature_dim)))
self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules))

self.point_linear = nn.Sequential(
OrderedDict(
[
("conv", nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
("bn", nn.BatchNorm2d(out_channels)),
]
)
)

def forward(self, x):
if self.inverted_bottleneck:
x = self.inverted_bottleneck(x)
x = self.depth_conv(x)
x = self.point_linear(x)
return x

@property
def module_str(self):
if self.mid_channels is None:
expand_ratio = self.expand_ratio
else:
expand_ratio = self.mid_channels // self.in_channels
layer_str = "%dx%d_MBConv%d_%s" % (
self.kernel_size,
self.kernel_size,
expand_ratio,
self.act_func.upper(),
)
if self.use_se:
layer_str = "SE_" + layer_str
layer_str += "_O%d" % self.out_channels
        if self.channels_per_group is not None:
            layer_str += "_G%d" % self.channels_per_group
if isinstance(self.point_linear.bn, nn.GroupNorm):
layer_str += "_GN%d" % self.point_linear.bn.num_groups
elif isinstance(self.point_linear.bn, nn.BatchNorm2d):
layer_str += "_BN"

return layer_str

@property
def config(self):
return {
"name": MBConvLayer.__name__,
"in_channels": self.in_channels,
"out_channels": self.out_channels,
"kernel_size": self.kernel_size,
"stride": self.stride,
"expand_ratio": self.expand_ratio,
"mid_channels": self.mid_channels,
"act_func": self.act_func,
"use_se": self.use_se,
"groups": self.groups,
}

@staticmethod
def build_from_config(config):
return MBConvLayer(**config)
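
A quick shape check for the static MBConvLayer above (illustrative values):

import torch

block = MBConvLayer(in_channels=24, out_channels=40, kernel_size=5, stride=2,
                    expand_ratio=4, act_func='relu6', use_se=True)
y = block(torch.randn(1, 24, 56, 56))        # -> torch.Size([1, 40, 28, 28])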

+ 134
- 0
xnas/spaces/BigNAS/utils.py View File

@@ -0,0 +1,134 @@
# Implementation adapted from AttentiveNAS: https://github.com/facebookresearch/AttentiveNAS

import torch
import torch.nn as nn
import copy
import math

multiply_adds = 1


def count_convNd(m, _, y):
cin = m.in_channels

kernel_ops = m.weight.size()[2] * m.weight.size()[3]
ops_per_element = kernel_ops
output_elements = y.nelement()

# cout x oW x oH
total_ops = cin * output_elements * ops_per_element // m.groups
m.total_ops = torch.Tensor([int(total_ops)])


def count_linear(m, _, __):
total_ops = m.in_features * m.out_features

m.total_ops = torch.Tensor([int(total_ops)])


register_hooks = {
nn.Conv1d: count_convNd,
nn.Conv2d: count_convNd,
nn.Conv3d: count_convNd,
######################################
nn.Linear: count_linear,
######################################
nn.Dropout: None,
nn.Dropout2d: None,
nn.Dropout3d: None,
nn.BatchNorm2d: None,
}


def profile(model, input_size=(1, 3, 224, 224), custom_ops=None):
handler_collection = []
custom_ops = {} if custom_ops is None else custom_ops

def add_hooks(m_):
if len(list(m_.children())) > 0:
return

m_.register_buffer('total_ops', torch.zeros(1))
m_.register_buffer('total_params', torch.zeros(1))

for p in m_.parameters():
m_.total_params += torch.Tensor([p.numel()])

m_type = type(m_)
fn = None

if m_type in custom_ops:
fn = custom_ops[m_type]
elif m_type in register_hooks:
fn = register_hooks[m_type]
else:
# print("Not implemented for ", m_)
pass

if fn is not None:
# print("Register FLOP counter for module %s" % str(m_))
_handler = m_.register_forward_hook(fn)
handler_collection.append(_handler)

original_device = model.parameters().__next__().device
training = model.training

model.eval()
model.apply(add_hooks)

x = torch.zeros(input_size).to(original_device)
with torch.no_grad():
model(x)

total_ops = 0
total_params = 0
for m in model.modules():
if len(list(m.children())) > 0: # skip for non-leaf module
continue
total_ops += m.total_ops
total_params += m.total_params

total_ops = total_ops.item()
total_params = total_params.item()

model.train(training)
model.to(original_device)

for handler in handler_collection:
handler.remove()

return total_ops, total_params


def count_net_flops_and_params(net, data_shape=(1, 3, 224, 224)):
if isinstance(net, nn.DataParallel):
net = net.module

net = copy.deepcopy(net)
flop, nparams = profile(net, data_shape)
return flop /1e6, nparams /1e6


def init_model(self, model_init="he_fout"):
""" Conv2d, BatchNorm2d, BatchNorm1d, Linear, """
for m in self.modules():
if isinstance(m, nn.Conv2d):
if model_init == 'he_fout':
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif model_init == 'he_fin':
n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
else:
raise NotImplementedError
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
if m.affine:
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
stdv = 1. / math.sqrt(m.weight.size(1))
m.weight.data.uniform_(-stdv, stdv)
if m.bias is not None:
m.bias.data.zero_()
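
A minimal check of the FLOPs/params counters above on a toy network (illustrative model; the numbers are approximate):

import torch.nn as nn

toy = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=2, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)
flops_m, params_m = count_net_flops_and_params(toy, data_shape=(1, 3, 224, 224))
# flops_m is roughly 5.4 (MFLOPs), dominated by the single 3x3 convolution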

+ 14
- 10
xnas/spaces/OFA/ProxylessNet/cnn.py View File

@@ -140,14 +140,14 @@ class MobileNetV2(ProxylessNASNet):
         width_mult=1.0,
         bn_param=(0.1, 1e-3),
         dropout_rate=0.2,
-        ks=None,
-        expand_ratio=None,
+        ks=None,  # a list only include {3, 5, 7}
+        expand_ratio=None,  # in proxyless space only 3 or 6
         depth_param=None,
         stage_width_list=None,
     ):

         ks = 3 if ks is None else ks
-        expand_ratio = 6 if expand_ratio is None else expand_ratio
+        expand_ratio = [6]*6 if expand_ratio is None else expand_ratio

         input_channel = 32
         last_channel = 1280
@@ -162,12 +162,12 @@ class MobileNetV2(ProxylessNASNet):
         inverted_residual_setting = [
             # t, c, n, s
             [1, 16, 1, 1],
-            [expand_ratio, 24, 2, 2],
-            [expand_ratio, 32, 3, 2],
-            [expand_ratio, 64, 4, 2],
-            [expand_ratio, 96, 3, 1],
-            [expand_ratio, 160, 3, 2],
-            [expand_ratio, 320, 1, 1],
+            [None, 24, 2, 2],
+            [None, 32, 3, 2],
+            [None, 64, 4, 2],
+            [None, 96, 3, 1],
+            [None, 160, 3, 2],
+            [None, 320, 1, 1],
         ]

         if depth_param is not None:
@@ -179,6 +179,10 @@ class MobileNetV2(ProxylessNASNet):
             for i in range(len(inverted_residual_setting)):
                 inverted_residual_setting[i][1] = stage_width_list[i]

+        if expand_ratio is not None:
+            for i in range(len(inverted_residual_setting)):
+                inverted_residual_setting[i][0] = expand_ratio[i]
+
         ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
         _pt = 0

@@ -201,7 +205,7 @@ class MobileNetV2(ProxylessNASNet):
                     stride = s
                 else:
                     stride = 1
-                if t == 1:
+                if t == 1:  # only used for first block
                     kernel_size = 3
                 else:
                     kernel_size = ks[_pt]
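
A standalone illustration (plain Python, not repo code) of what the per-stage expand_ratio change in this file amounts to: the single shared ratio becomes a list whose entries are written into the `t` column of the setting table.

inverted_residual_setting = [
    # t, c, n, s
    [1, 16, 1, 1],
    [None, 24, 2, 2],
    [None, 32, 3, 2],
]
expand_ratio = [1, 6, 3]  # hypothetical per-stage values
for i in range(len(inverted_residual_setting)):
    inverted_residual_setting[i][0] = expand_ratio[i]
print(inverted_residual_setting)  # [[1, 16, 1, 1], [6, 24, 2, 2], [3, 32, 3, 2]]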


+ 8
- 26
xnas/spaces/OFA/dynamic_ops.py View File

@@ -488,36 +488,18 @@ class DynamicMBConvLayer(nn.Module):
         self.use_se = use_se

         # build modules
-        max_middle_channel = make_divisible(
-            round(max(self.in_channel_list) * max(self.expand_ratio_list)))
+        max_middle_channel = make_divisible(round(max(self.in_channel_list) * max(self.expand_ratio_list)))
         if max(self.expand_ratio_list) == 1:
             self.inverted_bottleneck = None
         else:
-            self.inverted_bottleneck = nn.Sequential(
-                OrderedDict(
-                    [
-                        (
-                            "conv",
-                            DynamicConv2d(
-                                max(self.in_channel_list), max_middle_channel
-                            ),
-                        ),
-                        ("bn", DynamicBatchNorm2d(max_middle_channel)),
-                        ("act", build_activation(self.act_func)),
-                    ]
-                )
-            )
+            self.inverted_bottleneck = nn.Sequential(OrderedDict([
+                ("conv", DynamicConv2d(max(self.in_channel_list), max_middle_channel)),
+                ("bn", DynamicBatchNorm2d(max_middle_channel)),
+                ("act", build_activation(self.act_func)),
+            ]))

-        self.depth_conv = nn.Sequential(
-            OrderedDict(
-                [
-                    (
-                        "conv",
-                        DynamicSeparableConv2d(
-                            max_middle_channel, self.kernel_size_list, self.stride,
-                            kernel_trans=kernel_trans
-                        ),
-                    ),
+        self.depth_conv = nn.Sequential(OrderedDict([
+            ("conv", DynamicSeparableConv2d(max_middle_channel, self.kernel_size_list, stride=self.stride, kernel_trans=kernel_trans)),
             ("bn", DynamicBatchNorm2d(max_middle_channel)),
             ("act", build_activation(self.act_func)),
         ]


+ 88
- 18
xnas/spaces/OFA/ops.py View File

@@ -7,6 +7,7 @@ from xnas.spaces.OFA.utils import (
     min_divisible_value,
     get_same_padding,
     make_divisible,
+    drop_connect,
 )

@@ -46,12 +47,32 @@ def build_activation(act_func, inplace=True):
         return Hswish(inplace=inplace)
     elif act_func == "h_sigmoid":
         return Hsigmoid(inplace=inplace)
+    elif act_func == 'swish':
+        return MemoryEfficientSwish()
     elif act_func is None or act_func == "none":
         return None
     else:
         raise ValueError("do not support: %s" % act_func)


+class SwishImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, i):
+        result = i * torch.sigmoid(i)
+        ctx.save_for_backward(i)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        i = ctx.saved_tensors[0]
+        sigmoid_i = torch.sigmoid(i)
+        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class MemoryEfficientSwish(nn.Module):
+    def forward(self, x):
+        return SwishImplementation.apply(x)
+
+
 class Hswish(nn.Module):
     def __init__(self, inplace=True):
         super(Hswish, self).__init__()

@@ -637,27 +658,20 @@ class MBConvLayer(nn.Module):
         if self.expand_ratio == 1:
             self.inverted_bottleneck = None
         else:
-            self.inverted_bottleneck = nn.Sequential(
-                OrderedDict(
-                    [
-                        (
-                            "conv",
-                            nn.Conv2d(
-                                self.in_channels, feature_dim, 1, 1, 0, bias=False
-                            ),
-                        ),
-                        ("bn", nn.BatchNorm2d(feature_dim)),
-                        ("act", build_activation(self.act_func, inplace=True)),
-                    ]
-                )
-            )
+            self.inverted_bottleneck = nn.Sequential(OrderedDict([
+                ("conv", nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
+                ("bn", nn.BatchNorm2d(feature_dim)),
+                ("act", build_activation(self.act_func, inplace=True)),
+            ]))

         pad = get_same_padding(self.kernel_size)
-        groups = (
+        active_groups = (
             feature_dim
             if self.groups is None
             else min_divisible_value(feature_dim, self.groups)
         )
+        # assert feature_dim % self.groups == 0
+        # active_groups = feature_dim // self.groups
         depth_conv_modules = [
             (
                 "conv",

@@ -667,7 +681,7 @@ class MBConvLayer(nn.Module):
                     kernel_size,
                     stride,
                     pad,
-                    groups=groups,
+                    groups=active_groups,
                     bias=False,
                 ),
             ),

@@ -739,19 +753,26 @@ class MBConvLayer(nn.Module):


 class ResidualBlock(nn.Module):
-    def __init__(self, conv, shortcut):
+    def __init__(self, conv, shortcut, drop_connect_rate=0):
         super(ResidualBlock, self).__init__()

         self.conv = conv
+        self.mobile_inverted_conv = self.conv  # BigNAS
         self.shortcut = shortcut
+        self.drop_connect_rate = drop_connect_rate

     def forward(self, x):
+        in_channel = x.size(1)
         if self.conv is None or isinstance(self.conv, ZeroLayer):
             res = x
         elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
             res = self.conv(x)
         else:
-            res = self.conv(x) + self.shortcut(x)
+            im = self.shortcut(x)
+            x = self.conv(x)
+            if self.drop_connect_rate > 0 and in_channel == im.size(1) and self.shortcut.reduction == 1:
+                x = drop_connect(x, p=self.drop_connect_rate, training=self.training)
+            res = x + im
         return res

     @property

@@ -955,3 +976,52 @@ class ResNetBottleneckBlock(nn.Module):
     @staticmethod
     def build_from_config(config):
         return ResNetBottleneckBlock(**config)


class ShortcutLayer(nn.Module):
"""
NOTE:
This class implements similar functionality to `IdentityLayer`,
but adds and removes part of the implementation.
"""
def __init__(self, in_channels, out_channels, reduction=1):
super(ShortcutLayer, self).__init__()

self.in_channels = in_channels
self.out_channels = out_channels
self.reduction = reduction

self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)

def forward(self, x):
if self.reduction > 1:
padding = 0 if x.size(-1) % 2 == 0 else 1
x = F.avg_pool2d(x, self.reduction, padding=padding)
if self.in_channels != self.out_channels:
x = self.conv(x)
return x

@property
def module_str(self):
if self.in_channels == self.out_channels and self.reduction == 1:
conv_str = 'IdentityShortcut'
else:
if self.reduction == 1:
conv_str = '%d-%d_Shortcut' % (self.in_channels, self.out_channels)
else:
conv_str = '%d-%d_R%d_Shortcut' % (self.in_channels, self.out_channels, self.reduction)
return conv_str

@property
def config(self):
return {
'name': ShortcutLayer.__name__,
'in_channels': self.in_channels,
'out_channels': self.out_channels,
'reduction': self.reduction,
}

@staticmethod
def build_from_config(config):
return ShortcutLayer(**config)
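
A shape-only sketch of how the new ShortcutLayer behaves (assuming it is importable from xnas.spaces.OFA.ops as added above; the sizes are arbitrary):

import torch
from xnas.spaces.OFA.ops import ShortcutLayer

sc = ShortcutLayer(in_channels=32, out_channels=64, reduction=2)
x = torch.randn(1, 32, 56, 56)
y = sc(x)          # 2x2 average pooling down to 28x28, then 1x1 conv to 64 channels
print(y.shape)     # torch.Size([1, 64, 28, 28])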

+ 25
- 0
xnas/spaces/OFA/utils.py View File

@@ -47,6 +47,7 @@ def get_same_padding(kernel_size):
     assert kernel_size % 2 > 0, "kernel size should be odd number"
     return kernel_size // 2

+
 def make_divisible(v, divisor=8, min_val=None):
     """
     This function is taken from the original tf repo.

@@ -67,6 +68,30 @@ def make_divisible(v, divisor=8, min_val=None):
     return new_v


+def drop_connect(inputs, p, training):
+    """Drop connect.
+    Args:
+        input (tensor: BCWH): Input of this structure.
+        p (float: 0.0~1.0): Probability of drop connection.
+        training (bool): The running mode.
+    Returns:
+        output: Output after drop connection.
+    """
+    assert 0 <= p <= 1, 'p must be in range of [0,1]'
+    if not training:
+        return inputs
+    batch_size = inputs.shape[0]
+    keep_prob = 1.0 - p
+
+    # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
+    random_tensor = keep_prob
+    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
+    binary_tensor = torch.floor(random_tensor)
+
+    output = inputs / keep_prob * binary_tensor
+    return output
+
+
 """ BN related """


 def clean_num_batch_tracked(net):
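
For reference, a small sketch of the expected drop_connect behaviour (the mask is stochastic; only the rescaling and the eval-mode identity are guaranteed):

import torch
from xnas.spaces.OFA.utils import drop_connect

x = torch.ones(8, 3, 4, 4)
# Training: each sample in the batch is zeroed with probability p,
# survivors are rescaled by 1 / (1 - p) so the expected value is unchanged.
y = drop_connect(x, p=0.2, training=True)
print(torch.unique(y))   # tensor([0.0000, 1.2500]) up to randomness
# Eval: identity.
assert torch.equal(drop_connect(x, p=0.2, training=False), x)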


+ 271
- 0
xnas/spaces/ProxylessNAS/cnn.py View File

@@ -0,0 +1,271 @@
import json
import numpy as np
import torch.nn as nn

from xnas.spaces.OFA.ops import (
set_layer_from_config,
MBConvLayer,
ConvLayer,
IdentityLayer,
LinearLayer,
ResidualBlock,
GlobalAvgPool2d,
)
from xnas.spaces.OFA.utils import val2list, make_divisible
from xnas.spaces.OFA.MobileNetV3.cnn import WSConv_Network


__all__ = ["proxyless_base", "ProxylessNASNet", "MobileNetV2"]


def proxyless_base(
net_config=None,
n_classes=None,
bn_param=None,
dropout_rate=None,
):
assert net_config is not None, "Please input a network config"
net_config_json = json.load(open(net_config, "r"))

if n_classes is not None:
net_config_json["classifier"]["out_features"] = n_classes
if dropout_rate is not None:
net_config_json["classifier"]["dropout_rate"] = dropout_rate

net = ProxylessNASNet.build_from_config(net_config_json)
if bn_param is not None:
net.set_bn_param(*bn_param)

return net


class ProxylessNASNet(WSConv_Network):
def __init__(self, first_conv, blocks, feature_mix_layer, classifier):
super(ProxylessNASNet, self).__init__()

self.first_conv = first_conv
self.blocks = nn.ModuleList(blocks)
self.feature_mix_layer = feature_mix_layer
self.global_avg_pool = GlobalAvgPool2d(keep_dim=False)
self.classifier = classifier

def forward(self, x):
x = self.first_conv(x)
for block in self.blocks:
x = block(x)
if self.feature_mix_layer is not None:
x = self.feature_mix_layer(x)
x = self.global_avg_pool(x)
x = self.classifier(x)
return x

@property
def module_str(self):
_str = self.first_conv.module_str + "\n"
for block in self.blocks:
_str += block.module_str + "\n"
_str += self.feature_mix_layer.module_str + "\n"
_str += self.global_avg_pool.__repr__() + "\n"
_str += self.classifier.module_str
return _str

@property
def config(self):
return {
"name": ProxylessNASNet.__name__,
"bn": self.get_bn_param(),
"first_conv": self.first_conv.config,
"blocks": [block.config for block in self.blocks],
"feature_mix_layer": None
if self.feature_mix_layer is None
else self.feature_mix_layer.config,
"classifier": self.classifier.config,
}

@staticmethod
def build_from_config(config):
first_conv = set_layer_from_config(config["first_conv"])
feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
classifier = set_layer_from_config(config["classifier"])

blocks = []
for block_config in config["blocks"]:
blocks.append(ResidualBlock.build_from_config(block_config))

net = ProxylessNASNet(first_conv, blocks, feature_mix_layer, classifier)
if "bn" in config:
net.set_bn_param(**config["bn"])
else:
net.set_bn_param(momentum=0.1, eps=1e-3)

return net

def zero_last_gamma(self):
for m in self.modules():
if isinstance(m, ResidualBlock):
if isinstance(m.conv, MBConvLayer) and isinstance(
m.shortcut, IdentityLayer
):
m.conv.point_linear.bn.weight.data.zero_()

@property
def grouped_block_index(self):
info_list = []
block_index_list = []
for i, block in enumerate(self.blocks[1:], 1):
if block.shortcut is None and len(block_index_list) > 0:
info_list.append(block_index_list)
block_index_list = []
block_index_list.append(i)
if len(block_index_list) > 0:
info_list.append(block_index_list)
return info_list

def load_state_dict(self, state_dict, **kwargs):
current_state_dict = self.state_dict()

for key in state_dict:
if key not in current_state_dict:
assert ".mobile_inverted_conv." in key
new_key = key.replace(".mobile_inverted_conv.", ".conv.")
else:
new_key = key
current_state_dict[new_key] = state_dict[key]
super(ProxylessNASNet, self).load_state_dict(current_state_dict)


class MobileNetV2(ProxylessNASNet):
def __init__(
self,
n_classes=1000,
width_mult=1.0,
bn_param=(0.1, 1e-3),
dropout_rate=0.2,
ks=None, # a list only include {3, 5, 7}
expand_ratio=None, # in proxyless space only 3 or 6
depth_param=None,
stage_width_list=None,
):

ks = 3 if ks is None else ks
expand_ratio = [6]*6 if expand_ratio is None else expand_ratio

input_channel = 32
last_channel = 1280

input_channel = make_divisible(input_channel * width_mult)
last_channel = (
make_divisible(last_channel * width_mult)
if width_mult > 1.0
else last_channel
)

inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[0, 24, 2, 2],
[0, 32, 3, 2],
[0, 64, 4, 2],
[0, 96, 3, 1],
[0, 160, 3, 2],
[0, 320, 1, 1],
]

if depth_param is not None:
assert len(depth_param) == 6
# assert isinstance(depth_param, )
for i in range(1, len(inverted_residual_setting) - 1):
inverted_residual_setting[i][2] = depth_param[i-1]

if stage_width_list is not None:
assert len(stage_width_list) == 7
for i in range(len(inverted_residual_setting)):
inverted_residual_setting[i][1] = stage_width_list[i]

# if expand_ratio is not None:
# for i in range(len(inverted_residual_setting)):
# inverted_residual_setting[i][0] = expand_ratio[i]

# ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
_pt = 0

self.feature_idx = np.cumsum(depth_param)[[1, 3]]
# first conv layer
first_conv = ConvLayer(
3,
input_channel,
kernel_size=3,
stride=2,
use_bn=True,
act_func="relu6",
ops_order="weight_bn_act",
)
# inverted residual blocks
blocks = []
for t, c, n, s in inverted_residual_setting:
output_channel = make_divisible(c * width_mult)
for i in range(n):
if i == 0:
stride = s
else:
stride = 1
if t == 1: # only used for first block
kernel_size = 3
er = 1
else:
kernel_size = ks[_pt].item()
er = expand_ratio[_pt].item()
_pt += 1
mobile_inverted_conv = MBConvLayer(
in_channels=input_channel,
out_channels=output_channel,
kernel_size=kernel_size,
stride=stride,
expand_ratio=er,
)
if stride == 1:
if input_channel == output_channel:
shortcut = IdentityLayer(input_channel, input_channel)
else:
shortcut = None
else:
shortcut = None
blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
input_channel = output_channel
# 1x1_conv before global average pooling
feature_mix_layer = ConvLayer(
input_channel,
last_channel,
kernel_size=1,
use_bn=True,
act_func="relu6",
ops_order="weight_bn_act",
)

classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

super(MobileNetV2, self).__init__(
first_conv, blocks, feature_mix_layer, classifier
)

# set bn param
self.set_bn_param(*bn_param)
def forward_with_features(self, x, *args, **kwargs):
x = self.first_conv(x)
features = []
for i, block in enumerate(self.blocks):
if i in (self.feature_idx):
features.append(x)
x = block(x)
if self.feature_mix_layer is not None:
x = self.feature_mix_layer(x)
features.append(x)
assert len(features) == 3
x = self.global_avg_pool(x)
logits = self.classifier(x)
return features, logits

def _MobileNetV2():
return MobileNetV2()
