#16 dev

Merged
xfey merged 5 commits from dev into master 1 year ago
  1. +1 -0 .gitignore
  2. +9 -0 .readthedocs.yml
  3. +2 -3 README.md
  4. +70 -0 docs/conf.py
  5. +12 -2 docs/data_preparation.md
  6. +7 -7 docs/get_started.md
  7. +35 -0 docs/index.rst
  8. +2 -2 docs/notes.md
  9. +4 -0 docs/requirements.txt
  10. +2 -1 scripts/search/DrNAS.py
  11. +101 -0 scripts/search/OFA/train_supernet.py
  12. +169 -0 scripts/search/SNG/search.py
  13. +11 -7 scripts/search/SPOS.py
  14. +166 -0 xnas/algorithms/OFA/progressive_shrinking.py
  15. +9 -1 xnas/core/builder.py
  16. +5 -5 xnas/datasets/imagenet.py
  17. +4 -0 xnas/datasets/loader.py
  18. +55 -0 xnas/runner/criterion.py
  19. +3 -13 xnas/runner/optimizer.py
  20. +40 -28 xnas/runner/trainer.py
  21. +3 -0 xnas/spaces/DARTS/cnn.py
  22. +395 -0 xnas/spaces/OFA/MobileNetV3/cnn.py
  23. +415 -0 xnas/spaces/OFA/MobileNetV3/ofa_cnn.py
  24. +242 -0 xnas/spaces/OFA/ProxylessNet/cnn.py
  25. +390 -0 xnas/spaces/OFA/ProxylessNet/ofa_cnn.py
  26. +248 -0 xnas/spaces/OFA/ResNets/cnn.py
  27. +353 -0 xnas/spaces/OFA/ResNets/ofa_cnn.py
  28. +1182 -0 xnas/spaces/OFA/dynamic_ops.py
  29. +957 -0 xnas/spaces/OFA/ops.py
  30. +111 -0 xnas/spaces/OFA/utils.py
  31. +3 -0 xnas/spaces/SPOS/cnn.py

+1 -0 .gitignore

@@ -83,6 +83,7 @@ instance/

# Sphinx documentation
docs/_build/
docs/build/

# PyBuilder
target/


+9 -0 .readthedocs.yml

@@ -0,0 +1,9 @@
version: 2

sphinx:
configuration: docs/conf.py

python:
version: 3.7
install:
- requirements: docs/requirements.txt

+2 -3 README.md

@@ -67,13 +67,13 @@ We are gradually providing support for more settings.

To run XNAS, `python>=3.7` and `pytorch=1.9` are required. Other versions of `PyTorch` may also work, but potential API differences can generate warnings.

For detailed instructions, please refer to [**Getting_started.md**](./docs/Getting_started.md) and [**Data_preparation.md**](./docs/Data_preparation.md) in our docs.
For detailed instructions, please refer to [**get_started.md**](./docs/get_started.md) and [**data_preparation.md**](./docs/data_preparation.md) in our docs.

## Contributing

We welcome contributions to the library along with any potential issues or suggestions.

Please refer to [**Contributing.md**](./docs/Contributing.md) in our docs for more information.
Please refer to [**Contributing.md**](./docs/notes.md) in our docs for more information.

## Citation

@@ -110,7 +110,6 @@ This project is released under the [MIT license](https://mit-license.org).

- Migrate the OFA code
- Add 101 & 201 installation tests
- Complete the SNG Search code
- Check the 201 search space
- Check RMINAS
- Add module test cases


+70 -0 docs/conf.py

@@ -0,0 +1,70 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))


# -- Project information -----------------------------------------------------

project = 'XNAS'
copyright = '2022, PCL_AutoML'
author = 'PCL_AutoML'

# The full version, including alpha/beta/rc tags
release = 'v0.0.1'


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'recommonmark',
'sphinx_markdown_tables'
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = 'en'
# language = 'zh_CN'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#

# html_theme = 'alabaster'
# html_theme = 'sphinx_rtd_theme'
html_theme = 'furo'



# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

docs/Data_preparation.md → docs/data_preparation.md

@@ -1,4 +1,4 @@
## Data Preparation
# Common Settings

It is highly recommended to save or link datasets to the `$XNAS/data` folder, so that no additional configuration is required.

@@ -7,7 +7,7 @@ However, manually setting the path for datasets is also available by modifying t
Additionally, files required by benchmarks are also in the `$XNAS/data` folder. You can also modify related attributes under `cfg.BENCHMARK` in the configuration file, to match your actual file locations.


### Supported Datasets
# Dataset Preparation

The dataloaders of XNAS read dataset files from `$XNAS/data/$DATASET_NAME` by default, where we use lowercase names with hyphens removed. For example, files for CIFAR-10 should be placed (or auto-downloaded) under the `$XNAS/data/cifar/` directory.

@@ -21,3 +21,13 @@ XNAS currently supports the following datasets.
- MNIST
- FashionMNIST

# Benchmark Preparation

Some search spaces or algorithms supported by XNAS require specific APIs provided by NAS benchmarks. Installing and properly configuring these benchmarks is required to run such code.

Benchmarks supported by XNAS and their links are listed below.
- nasbench101: [GitHub](https://github.com/google-research/nasbench)
- nasbench1shot1: [GitHub](https://github.com/automl/nasbench-1shot1)
- nasbench201: [GitHub](https://github.com/D-X-Y/NAS-Bench-201)
- nasbench301: [GitHub](https://github.com/automl/nasbench301)


docs/Getting_started.md → docs/get_started.md

@@ -1,10 +1,10 @@
## Getting Started
# Prerequisites

XNAS does not currently provide installation via `pip`. To run XNAS, `python>=3.7` and `pytorch==1.9` are required. Other versions of `PyTorch` may also work, but potential API differences can generate warnings.

We have listed other requirements in the `requirements.txt` file.

## Installation
# Installation

1. Clone this repo.
2. (Optional) Create a virtualenv for this library.
@@ -38,16 +38,16 @@ Benchmarks supported by XNAS and their links are listed below.
- nasbench201: [GitHub](https://github.com/D-X-Y/NAS-Bench-201)
- nasbench301: [GitHub](https://github.com/automl/nasbench301)

For detailed instructions to install these benchmarks, please refer to the `$XNAS/docs/benchmarks` directory.
For detailed instructions to install these benchmarks, please refer to [**Data Preparation**](./data_preparation.md).


## Usage
# Usage

Before running code in XNAS, please make sure you have followed instructions in [**Data_preparation.md**](./Data_preparation.md) in our docs to complete preparing the necessary data.
Before running code in XNAS, please make sure you have followed the instructions in [**Data Preparation**](./data_preparation.md) in our docs to finish preparing the necessary data.

The main program entries for the search and training process are in the `$XNAS/scripts` folder. To modify and add NAS code, please place files in this folder.

### Configuration Files
## Configuration Files

XNAS uses the `.yaml` file format to organize configuration files. All configuration files are placed under the `$XNAS/configs` directory. To keep the files uniform and clear, we strongly recommend using the following naming convention:

@@ -61,7 +61,7 @@ For example, using `DARTS` algorithm, searching on `NASBench201` space and `CIFA
darts_nasbench201_cifar10_nasbench301_maxepoch75.yaml
```

### Running Examples
## Running Examples

XNAS reads configuration files from the command line. A simple running example follows:
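For reference, the entry scripts added in this PR (e.g. `scripts/search/OFA/train_supernet.py` and `scripts/search/SNG/search.py`) all follow the same pattern: `config.load_configs()` is called once at module level to parse the configuration passed on the command line, and everything else is read from the global `cfg`. A minimal sketch of such an entry script (the config keys shown are only examples):

```
import xnas.core.config as config
from xnas.core.config import cfg

# parse the YAML configuration given on the command line into the global `cfg`
config.load_configs()

def main():
    # values below are filled in from the loaded .yaml file
    print(cfg.SPACE.NAME, cfg.OPTIM.MAX_EPOCH)

if __name__ == "__main__":
    main()
```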


+35 -0 docs/index.rst

@@ -0,0 +1,35 @@
.. XNAS documentation master file, created by
sphinx-quickstart on Sat May 21 14:09:17 2022.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.


Welcome to XNAS' documentation!
====================================


.. toctree::
:maxdepth: 2
:caption: Get Started

get_started.md
.. toctree::
:maxdepth: 2
:caption: Data Preparation

data_preparation.md


.. toctree::
:maxdepth: 2
:caption: Notes

notes.md


Indices and tables
==================

* :ref:`genindex`
* :ref:`search`

docs/Contributing.md → docs/notes.md

@@ -1,8 +1,8 @@
## Contributing to XNAS
# Contributing

We welcome contributions to the library along with any potential issues or suggestions.

### Pull Requests
## Pull Requests

We actively welcome your pull requests.


+4 -0 docs/requirements.txt

@@ -0,0 +1,4 @@
sphinx
furo
recommonmark
sphinx_markdown_tables

+2 -1 scripts/search/DrNAS.py

@@ -81,7 +81,7 @@ def main():
# check whether warm-up training is used
if cfg.LOADER.BATCH_SIZE <= 256 or cfg.LOADER.DATASET != 'imagenet':
cfg.OPTIM.WARMUP_EPOCH = 0 # DrNAS does not warm-up if batch_size is small
train_epochs[0] -= cfg.OPTIM.WARMUP_EPOCH
train_epochs[0] += cfg.OPTIM.WARMUP_EPOCH
# init recorders
drnas_trainer = DartsTrainer(
@@ -120,6 +120,7 @@ def main():
writer=drnas_trainer.writer,
cur_epoch=cur_epoch
)
start_epoch += 1
# set tau for snas & gdas
if TAU_FLAG:
tau_epoch += tau_step


+101 -0 scripts/search/OFA/train_supernet.py

@@ -0,0 +1,101 @@
"""OFA supernet training."""

import xnas.core.config as config
import xnas.logger.logging as logging
from xnas.core.config import cfg
from xnas.core.builder import *

# OFA
from xnas.runner.trainer import Trainer
from xnas.spaces.OFA.utils import init_model
from xnas.algorithms.OFA.progressive_shrinking import train_epoch, validate


# Load config and check
config.load_configs()
logger = logging.get_logger(__name__)

def main():
setup_env()
# Network
net = space_builder().cuda()
init_model(net)
# Loss function
criterion = criterion_builder()
# Data loaders
[train_loader, valid_loader] = construct_loader()
# Optimizers
net_params = [
# parameters with weight decay
{"params": net.get_parameters(['bn', 'bias'], mode="exclude"), "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
# parameters without weight decay
{"params": net.get_parameters(['bn', 'bias'], mode="include") , "weight_decay": 0},
]
# init_lr = cfg.OPTIM.BASE_LR * cfg.NUM_GPUS # TODO: multi-GPU support
optimizer = optimizer_builder("SGD", net_params)
lr_scheduler = lr_scheduler_builder(optimizer)
# Initialize Recorder
ofa_trainer = Trainer(
model=net,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
train_loader=train_loader,
test_loader=valid_loader,
)
# Resume
start_epoch = ofa_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0

# build validation config
validate_func_dict = {
"image_size_list": {cfg.TEST.IM_SIZE},
"ks_list": sorted({min(net.ks_list), max(net.ks_list)}),
"expand_ratio_list": sorted({min(net.expand_ratio_list), max(net.expand_ratio_list)}),
"depth_list": sorted({min(net.depth_list), max(net.depth_list)}),
}
if cfg.OFA.TASK == 'normal':
pass
elif cfg.OFA.TASK == 'kernel':
validate_func_dict["ks_list"] = sorted(net.ks_list)
elif cfg.OFA.TASK == 'depth':
# add depth list constraints
if (len(set(net.ks_list)) == 1) and (len(set(net.expand_ratio_list)) == 1):
validate_func_dict["depth_list"] = net.depth_list
elif cfg.OFA.TASK == 'expand':
if len(set(net.ks_list)) == 1 and len(set(net.depth_list)) == 1:
validate_func_dict["expand_ratio_list"] = net.expand_ratio_list
else:
raise NotImplementedError

# Training
logger.info("=== OFA | Task: {} | Phase: {} ===".format(cfg.OFA.TASK, cfg.OFA.PHASE))
ofa_trainer.start()
for cur_epoch in range(start_epoch, cfg.OPTIM.WARMUP_EPOCH+cfg.OPTIM.MAX_EPOCH):
train_epoch(
train_=train_loader,
net=net,
train_criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
writer=ofa_trainer.writer,
train_meter=ofa_trainer.train_meter,
cur_epoch=cur_epoch,
)
if (cur_epoch + 1) % cfg.SAVE_PERIOD == 0:
ofa_trainer.saving(cur_epoch)
if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH:
validate(
val_=valid_loader,
net=net,
val_meter=ofa_trainer.test_meter,
cur_epoch=cur_epoch,
logger=logger,
**validate_func_dict,
)
# TODO: OFA should support saving the checkpoint with the best validation accuracy
ofa_trainer.finish()


if __name__ == '__main__':
main()
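The two parameter groups passed to the optimizer above implement the usual "no weight decay on normalization parameters and biases" rule. A standalone sketch of the same split, independent of XNAS (module names and the decay value are illustrative):

```
import torch
import torch.nn as nn

model = nn.Sequential()
model.add_module("conv", nn.Conv2d(3, 8, 3, bias=False))
model.add_module("bn", nn.BatchNorm2d(8))
model.add_module("fc", nn.Linear(8, 4))

decay, no_decay = [], []
for name, p in model.named_parameters():
    # exclude BatchNorm parameters and all biases from weight decay
    (no_decay if ("bn" in name or name.endswith("bias")) else decay).append(p)

optimizer = torch.optim.SGD(
    [{"params": decay, "weight_decay": 5e-4},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=0.1, momentum=0.9,
)
```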

+169 -0 scripts/search/SNG/search.py

@@ -0,0 +1,169 @@
"""SNG searching

(DARTS space only)
"""

import torch
import random
import numpy as np

import xnas.core.config as config
import xnas.logger.meter as meter
import xnas.logger.logging as logging
from xnas.core.config import cfg
from xnas.core.builder import *
from xnas.core.utils import index_to_one_hot, one_hot_to_index

from xnas.runner.trainer import OneShotTrainer


# Load config and check
config.load_configs()
logger = logging.get_logger(__name__)


def main():
setup_env()
search_space = space_builder().cuda()
criterion = criterion_builder().cuda()
evaluator = evaluator_builder()
[train_loader, valid_loader] = construct_loader()
w_optim = optimizer_builder("SGD", search_space.parameters())
lr_scheduler = lr_scheduler_builder(w_optim)
if cfg.SPACE.NAME in ['darts']:
distribution_optimizer = SNG_builder([search_space.num_ops]*search_space.all_edges)
else:
raise NotImplementedError

# Trainer definition
sng_trainer = OneShotTrainer(
supernet=search_space,
criterion=criterion,
optimizer=w_optim,
lr_scheduler=lr_scheduler,
train_loader=train_loader,
test_loader=valid_loader,
)
over_all_epoch = 0
# === Warmup ===
logger.info("=== Warmup Training ===")
sng_trainer.start()
for cur_epoch in range(cfg.OPTIM.WARMUP_EPOCH):
if cfg.SNG.WARMUP_RANDOM_SAMPLE:
sample = random_sampling(search_space, distribution_optimizer, cur_epoch)
else:
num_ops, total_edges = search_space.num_ops, search_space.all_edges
array_sample = [random.sample(list(range(num_ops)), num_ops) for i in range(total_edges)]
array_sample = np.array(array_sample)
for i in range(num_ops):
sample = np.transpose(array_sample[:, i])
sample = index_to_one_hot(sample, distribution_optimizer.p_model.Cmax)
logger.info("Warmup Sampling: {}".format(one_hot_to_index(sample)))
sng_trainer.train_epoch(over_all_epoch, sample)
sng_trainer.test_epoch(over_all_epoch, sample)
over_all_epoch += 1
sng_trainer.finish()
logger.info("=== Training ===")
sng_trainer.start()
for cur_epoch in range(cfg.OPTIM.MAX_EPOCH):
if hasattr(distribution_optimizer, 'training_finish'):
if distribution_optimizer.training_finish:
break
sample = random_sampling(search_space, distribution_optimizer, epoch=cur_epoch, _random=cfg.SNG.RANDOM_SAMPLE)
logger.info("Sampling: {}".format(one_hot_to_index(sample)))
sng_trainer.train_epoch(over_all_epoch, sample)
top1_err = sng_trainer.test_epoch(over_all_epoch, sample)
over_all_epoch += 1
# TODO: REA & RAND in algorithm/SPOS are similar to this optimizer. Adding them to SNG series?
distribution_optimizer.record_information(sample, top1_err)
distribution_optimizer.update()
# Evaluate the model
if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH:
logger.info("=== Optimal genotype at epoch: {} ===".format(cur_epoch))
logger.info(search_space.genotype(distribution_optimizer.p_model.theta))
logger.info("=== alphas at epoch: {} ===".format(cur_epoch))
for alpha in distribution_optimizer.p_model.theta:
logger.info(alpha)
sng_trainer.finish()
logger.info("=== Final epochs ===")
sng_trainer.start()
for cur_epoch in range(cfg.OPTIM.FINAL_EPOCH):
if cfg.SPACE.NAME in ['darts']:
genotype = search_space.genotype(distribution_optimizer.p_model.theta)
sample = search_space.genotype_to_onehot_sample(genotype)
else:
sample = distribution_optimizer.sampling_best()
sng_trainer.train_epoch(over_all_epoch, sample)
sng_trainer.test_epoch(over_all_epoch, sample)
over_all_epoch += 1
sng_trainer.finish()

if cfg.SPACE.NAME in ['darts']:
best_genotype = search_space.genotype(distribution_optimizer.p_model.theta)
# evaluator(genotype) # TODO: NAS-Bench-301 support.


def random_sampling(search_space, distribution_optimizer, epoch=-1000, _random=False):
"""random sampling"""
if _random:
num_ops, total_edges = search_space.num_ops, search_space.all_edges
# Edge importance
non_edge_idx = []
if cfg.SNG.EDGE_SAMPLING and epoch > cfg.SNG.EDGE_SAMPLING_EPOCH:
assert cfg.SPACE.NAME in ['darts'], "only support darts for now!"
norm_indexes = search_space.norm_node_index
non_edge_idx = []
for node in norm_indexes:
# DARTS: N=7 nodes
edge_non_prob = distribution_optimizer.p_model.theta[np.array(node), 7]
edge_non_prob = edge_non_prob / np.sum(edge_non_prob)
if len(node) == 2:
pass
else:
non_edge_sampling_num = len(node) - 2
non_edge_idx += list(np.random.choice(node, non_edge_sampling_num, p=edge_non_prob, replace=False))
# Big model sampling with probability
if random.random() < cfg.SNG.BIGMODEL_SAMPLE_PROB:
# Sample the network with high complexity
_num = 100
while _num > cfg.SNG.BIGMODEL_NON_PARA:
_error = False
if cfg.SNG.PROB_SAMPLING:
sample = np.array([np.random.choice(num_ops, 1, p=distribution_optimizer.p_model.theta[i, :])[0] for i in range(total_edges)])
else:
sample = np.array([np.random.choice(num_ops, 1)[0] for i in range(total_edges)])
_num = 0
for i in sample[0:search_space.num_edges]:
if i in non_edge_idx:
pass
elif i in search_space.non_op_idx:
if i == 7:
_error = True
_num = _num + 1
if _error:
_num = 100
else:
if cfg.SNG.PROB_SAMPLING:
sample = np.array([np.random.choice(num_ops, 1, p=distribution_optimizer.p_model.theta[i, :])[0]
for i in range(total_edges)])
else:
sample = np.array([np.random.choice(num_ops, 1)[0] for i in range(total_edges)])
if cfg.SNG.EDGE_SAMPLING and epoch > cfg.SNG.EDGE_SAMPLING_EPOCH:
for i in non_edge_idx:
sample[i] = 7
sample = index_to_one_hot(sample, distribution_optimizer.p_model.Cmax)
# in the pruning method we have to sample anyway
distribution_optimizer.sampling()
return sample
else:
return distribution_optimizer.sampling()


if __name__ == "__main__":
main()
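For context, the `sample` handed to the trainer above is a per-edge one-hot matrix: an index vector of length `all_edges` is expanded to shape `(all_edges, Cmax)`. A small illustration of that conversion with made-up values (this mirrors what `xnas.core.utils.index_to_one_hot` is assumed to do):

```
import numpy as np

def index_to_one_hot(indices, n_ops):
    # one row per edge, a single 1 at the chosen operation index (assumed behaviour)
    one_hot = np.zeros((len(indices), n_ops))
    one_hot[np.arange(len(indices)), indices] = 1
    return one_hot

print(index_to_one_hot(np.array([2, 0, 3]), 4))
# [[0. 0. 1. 0.]
#  [1. 0. 0. 0.]
#  [0. 0. 0. 1.]]
```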

+11 -7 scripts/search/SPOS.py

@@ -23,7 +23,8 @@ def main():
lr_scheduler = lr_scheduler_builder(optimizer)
# init sampler
train_sampler, evaluate_sampler = RAND(), REA()
train_sampler = RAND(cfg.SPOS.NUM_CHOICE, cfg.SPOS.LAYERS)
evaluate_sampler = REA(cfg.SPOS.NUM_CHOICE, cfg.SPOS.LAYERS)
# init recorders
spos_trainer = OneShotTrainer(
@@ -33,9 +34,9 @@ def main():
lr_scheduler=lr_scheduler,
train_loader=train_loader,
test_loader=valid_loader,
train_sampler=train_sampler,
evaluate_sampler=evaluate_sampler,
sample_type='iter'
)
spos_trainer.register_iter_sample(train_sampler)
# load checkpoint or initial weights
start_epoch = spos_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0
@@ -44,16 +45,19 @@ def main():
spos_trainer.start()
for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
# train epoch
spos_trainer.train_epoch(cur_epoch)
top1_err = spos_trainer.train_epoch(cur_epoch)
# test epoch
if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH:
spos_trainer.test_epoch(cur_epoch)
top1_err = spos_trainer.test_epoch(cur_epoch)
spos_trainer.finish()
# sample best architecture from supernet
best_arch, best_top1err = spos_trainer.best_arch()
for cycle in range(200): # NOTE: this should be a hyperparameter
sample = evaluate_sampler.suggest()
top1_err = spos_trainer.evaluate_epoch(sample)
evaluate_sampler.record(sample, top1_err)
best_arch, best_top1err = evaluate_sampler.final_best()
logger.info("Best arch: {} \nTop1 error: {}".format(best_arch, best_top1err))


if __name__ == '__main__':
main()

+166 -0 xnas/algorithms/OFA/progressive_shrinking.py

@@ -0,0 +1,166 @@
import random
import torch
import torch.nn as nn

import xnas.logger.meter as meter
from xnas.core.config import cfg
from xnas.spaces.OFA.utils import list_mean

__all__ = [
"validate",
"train_one_epoch",
]


def validate(
val_, net, val_meter,
cur_epoch, logger,
image_size_list=None,
ks_list=None,
expand_ratio_list=None,
depth_list=None,
width_mult_list=None,
additional_setting=None,
):
# net
dynamic_net = net
if isinstance(dynamic_net, nn.DataParallel):
dynamic_net = dynamic_net.module
# eval mode
dynamic_net.eval()

# net config
assert image_size_list is not None, 'validate: image_size should not be None'

if ks_list is None:
ks_list = dynamic_net.ks_list
if expand_ratio_list is None:
expand_ratio_list = dynamic_net.expand_ratio_list
if depth_list is None:
depth_list = dynamic_net.depth_list
if width_mult_list is None:
if "width_mult_list" in dynamic_net.__dict__:
width_mult_list = list(range(len(dynamic_net.width_mult_list)))
else:
width_mult_list = [0]


# collect the settings of all subnets
subnet_settings = []
# img_size = cfg.TEST.IM_SIZE
for d in depth_list:
for e in expand_ratio_list:
for k in ks_list:
for w in width_mult_list:
for img_size in image_size_list:
subnet_settings.append(
[
{
"img_size": img_size,
"d": d,
"e": e,
"ks": k,
"w": w,
},
"R%s-D%s-E%s-K%s-W%s" % (img_size, d, e, k, w),
]
)
if additional_setting is not None:
subnet_settings += additional_setting

# iterate over and evaluate all subnets
for setting, name in subnet_settings:
dynamic_net.set_active_subnet(**setting)
logger.info('epoch: '+str(cur_epoch+1) + ' || validate subnet: '+ name)
test_epoch(val_, dynamic_net, val_meter, cur_epoch)


def train_epoch(
train_, net, train_criterion,
optimizer, lr_scheduler, writer, train_meter,
cur_epoch,
):
nBatch = len(train_)
cur_step = cur_epoch*nBatch

for cur_iter, (images, labels) in enumerate(train_):
cur_step += 1
net.train()
train_meter.iter_tic() # initialize timing

images, labels = images.cuda(), labels.cuda()

# clean gradients
net.zero_grad()

# set random seed before sampling
subnet_seed = int("%d%.3d" % (cur_step, 0))
random.seed(subnet_seed)
# subnet setting
subnet_settings = net.sample_active_subnet()
subnet_str = ",".join(
[
"%s_%s"
% (
key,
"%.1f" % list_mean([val[0]])
if isinstance(val, list)
else val,
)
for key, val in subnet_settings.items()
]
)
# compute output
output = net(images)
loss = train_criterion(output, labels)
loss.backward()
optimizer.step()

# measure top1&top5 error
top1_err, top5_err = meter.topk_errors(output, labels, [1, 5])
# Copy the stats from GPU to CPU (sync point)
loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
train_meter.iter_toc()
# Update and log stats
mb_size = images.size(0)
cur_lr = lr_scheduler.get_last_lr()[0]
train_meter.update_stats(top1_err, top5_err, loss, cur_lr, mb_size)
train_meter.log_iter_stats(cur_epoch, cur_iter)
train_meter.iter_tic()
# write to tensorboard
writer.add_scalar('train/lr', cur_lr, cur_step)
writer.add_scalar('train/loss', loss, cur_step)
writer.add_scalar('train/top1_error', top1_err, cur_step)
writer.add_scalar('train/top5_error', top5_err, cur_step)
# update lr
lr_scheduler.step(cur_epoch + cur_iter / nBatch)
# Log epoch stats
train_meter.log_epoch_stats(cur_epoch)
train_meter.reset()
return


@torch.no_grad()
def test_epoch(test_loader, model, test_meter, cur_epoch, writer=None):
model.eval()
test_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(test_loader):
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
preds = model(inputs)
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item()
test_meter.iter_toc()
test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
test_meter.log_iter_stats(cur_epoch, cur_iter)
test_meter.iter_tic()
top1_err = test_meter.mb_top1_err.get_win_median()
if writer is not None:
writer.add_scalar('val/top1_error', test_meter.mb_top1_err.get_win_median(), cur_epoch)
writer.add_scalar('val/top5_error', test_meter.mb_top5_err.get_win_median(), cur_epoch)
# Log epoch stats
test_meter.log_epoch_stats(cur_epoch)
test_meter.reset()
return top1_err
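To make the size of the validation sweep above concrete: `subnet_settings` is the cartesian product of the five candidate lists, and each entry is tagged with an `R{img_size}-D{d}-E{e}-K{k}-W{w}` name. A self-contained illustration with made-up values:

```
from itertools import product

depth_list, expand_ratio_list, ks_list = [2, 4], [3, 6], [3, 7]
width_mult_list, image_size_list = [0], [224]

subnet_settings = [
    ({"img_size": r, "d": d, "e": e, "ks": k, "w": w},
     "R%s-D%s-E%s-K%s-W%s" % (r, d, e, k, w))
    for d, e, k, w, r in product(depth_list, expand_ratio_list, ks_list,
                                 width_mult_list, image_size_list)
]
print(len(subnet_settings))   # 2 * 2 * 2 * 1 * 1 = 8 subnets to validate
print(subnet_settings[0][1])  # R224-D2-E3-K3-W0
```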

+9 -1 xnas/core/builder.py

@@ -21,7 +21,8 @@ from xnas.core.config import cfg
# Dataloader
from xnas.datasets.loader import construct_loader
# Optimizers, criterions and LR_schedulers
from xnas.runner.optimizer import optimizer_builder, criterion_builder
from xnas.runner.optimizer import optimizer_builder
from xnas.runner.criterion import criterion_builder
from xnas.runner.scheduler import lr_scheduler_builder


@@ -31,6 +32,7 @@ __all__ = [
'criterion_builder',
'lr_scheduler_builder',
'space_builder',
'SNG_builder',
'evaluator_builder',
'setup_env',
]
@@ -48,6 +50,9 @@ from xnas.spaces.DrNAS.darts_cnn import _DrNAS_DARTS_CNN
from xnas.spaces.DrNAS.nb201_cnn import _DrNAS_nb201_CNN, _GDAS_nb201_CNN
from xnas.spaces.SPOS.cnn import _SPOS_CNN, _infer_SPOS_CNN
from xnas.spaces.DropNAS.cnn import _DropNASCNN
from xnas.spaces.OFA.MobileNetV3.ofa_cnn import _OFAMobileNetV3
from xnas.spaces.OFA.ProxylessNet.ofa_cnn import _OFAProxylessNASNet
from xnas.spaces.OFA.ResNets.ofa_cnn import _OFAResNet


SUPPORTED_SPACES = {
@@ -60,6 +65,9 @@ SUPPORTED_SPACES = {
"gdas_nb201": _GDAS_nb201_CNN,
"dropnas": _DropNASCNN,
"spos": _SPOS_CNN,
"ofa_mbv3": _OFAMobileNetV3,
"ofa_proxyless": _OFAProxylessNASNet,
"ofa_resnet": _OFAResNet,
# models for inference
"infer_darts": _infer_DartsCNN,
"infer_nb201": _infer_NASBench201,


+5 -5 xnas/datasets/imagenet.py

@@ -33,8 +33,8 @@ class ImageFolder():
_rgb_normalized_mean=None,
_rgb_normalized_std=None,
transforms=None,
num_workers=cfg.LOADER.NUM_WORKERS,
pin_memory=cfg.LOADER.PIN_MEMORY,
num_workers=None,
pin_memory=None,
shuffle=True
):
assert os.path.exists(datapath), "Data path '{}' not found".format(datapath)
@@ -42,10 +42,10 @@ class ImageFolder():
self._data_path, self._split, self.dataset_name = datapath, split, dataset_name
self._rgb_normalized_mean, self._rgb_normalized_std = _rgb_normalized_mean, _rgb_normalized_std
self.num_workers = num_workers
self.batch_size = batch_size
self.pin_memory = pin_memory
self.num_workers = cfg.LOADER.NUM_WORKERS if num_workers is None else num_workers
self.pin_memory = cfg.LOADER.PIN_MEMORY if pin_memory is None else pin_memory
self.shuffle = shuffle
self.batch_size = batch_size
if transforms is None:
self.transforms = [{'crop': 'random', 'crop_size': cfg.SEARCH.IM_SIZE, 'min_crop': 0.08, 'random_flip': True},
{'crop': 'center', 'crop_size': cfg.TEST.IM_SIZE, 'min_crop': -1, 'random_flip': False}] # NOTE: min_crop is not used here.
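Presumably the reason for moving these defaults to `None` (my reading of the diff, not stated in it): Python evaluates default argument values once, when the function is defined, so `num_workers=cfg.LOADER.NUM_WORKERS` would freeze whatever the config held at import time and ignore a config file loaded afterwards, while resolving `None` inside `__init__` defers the lookup to call time. A tiny generic illustration of the difference:

```
VALUE = 1

def f(x=VALUE):   # default captured now, at definition time
    return x

def g(x=None):    # resolved at call time
    return VALUE if x is None else x

VALUE = 2
print(f(), g())   # 1 2
```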


+4 -0 xnas/datasets/loader.py

@@ -43,6 +43,10 @@ def construct_loader(
batch_size = [256, 256]
assert len(batch_size) == len(split), "lengths of batch_size and split should be same."
# check if randomresized crop is used only in ImageFolder type datasets
if isinstance(cfg.SEARCH.IMG_SIZE, list):
assert name in IMAGEFOLDER_FORMAT, "RandomResizedCrop can only be used in ImageFolder currently."
if name in SUPPORTED_DATASETS:
# using training data only.
train_data, _ = get_data(name, datapath, cutout_length, use_classes=use_classes, transforms=transforms)


+55 -0 xnas/runner/criterion.py

@@ -0,0 +1,55 @@
"""Loss functions."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from xnas.core.config import cfg


__all__ = ['criterion_builder']


def smoothed_cross_entropy_loss(pred, target, label_smoothing=0.):
def _label_smooth(target, n_classes: int, label_smoothing):
# convert to one-hot
batch_size = target.size(0)
target = torch.unsqueeze(target, 1)
soft_target = torch.zeros((batch_size, n_classes), device=target.device)
soft_target.scatter_(1, target, 1)
# label smoothing
soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
return soft_target

label_smoothing = cfg.SEARCH.LABEL_SMOOTH if label_smoothing == 0. else label_smoothing
soft_target = _label_smooth(target, pred.size(1), label_smoothing)
logsoftmax = nn.LogSoftmax()
return torch.mean(torch.sum(-soft_target * logsoftmax(pred), 1))


class MultiHeadCrossEntropyLoss(nn.Module):
def forward(self, preds, targets):
assert preds.dim() == 3, preds
assert targets.dim() == 2, targets

assert preds.size(1) == targets.size(1), (preds, targets)
num_heads = targets.size(1)

loss = 0
for k in range(num_heads):
loss += F.cross_entropy(preds[:, k, :], targets[:, k]) / num_heads
return loss


# ----------

SUPPORTED_CRITERIONS = {
"cross_entropy": torch.nn.CrossEntropyLoss(),
"cross_entropy_smooth": smoothed_cross_entropy_loss,
"cross_entropy_multihead": MultiHeadCrossEntropyLoss()
}


def criterion_builder():
err_str = "Loss function type '{}' not supported"
assert cfg.SEARCH.LOSS_FUN in SUPPORTED_CRITERIONS.keys(), err_str.format(cfg.SEARCH.LOSS_FUN)
return SUPPORTED_CRITERIONS[cfg.SEARCH.LOSS_FUN]
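A quick worked check of the `_label_smooth` helper above (the numbers are my own example): with 4 classes, target class 2 and smoothing 0.1, the one-hot target `[0, 0, 1, 0]` becomes `[0.025, 0.025, 0.925, 0.025]`.

```
import torch

def label_smooth(target, n_classes, label_smoothing):
    # same construction as in xnas/runner/criterion.py above
    soft = torch.zeros(target.size(0), n_classes)
    soft.scatter_(1, target.unsqueeze(1), 1)
    return soft * (1 - label_smoothing) + label_smoothing / n_classes

print(label_smooth(torch.tensor([2]), 4, 0.1))
# tensor([[0.0250, 0.0250, 0.9250, 0.0250]])
```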

+3 -13 xnas/runner/optimizer.py

@@ -1,13 +1,13 @@
"""Optimizers and loss functions."""
"""Optimizers."""

import torch
import torch.nn as nn
from xnas.core.config import cfg


__all__ = [
'optimizer_builder',
'darts_alpha_optimizer',
'criterion_builder'
'darts_alpha_optimizer',
]


@@ -16,10 +16,6 @@ SUPPORTED_OPTIMIZERS = {
"Adam",
}

SUPPORTED_CRITERIONS = {
"cross_entropy": torch.nn.CrossEntropyLoss(),
}


def optimizer_builder(name, param):
"""optimizer builder
@@ -69,9 +65,3 @@ def darts_alpha_optimizer(name, param):
betas=(0.5, 0.999),
weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY,
)


def criterion_builder():
err_str = "Loss function type '{}' not supported"
assert cfg.SEARCH.LOSS_FUN in SUPPORTED_CRITERIONS.keys(), err_str.format(cfg.SEARCH.LOSS_FUN)
return SUPPORTED_CRITERIONS[cfg.SEARCH.LOSS_FUN]

+40 -28 xnas/runner/trainer.py

@@ -22,7 +22,7 @@ from xnas.core.config import cfg
from torch.utils.tensorboard import SummaryWriter


__all__ = ["Trainer", "DartsTrainer"]
__all__ = ["Trainer", "DartsTrainer", "OneShotTrainer"]


logger = logging.get_logger(__name__)
@@ -47,6 +47,7 @@ class Recorder():
self.full_timer.toc()
logger.info("Overall time cost: {}".format(str(self.full_timer.total_time)))
gc.collect()
self.full_timer = None


class Trainer(Recorder):
@@ -154,8 +155,10 @@ class Trainer(Recorder):
"""Load from checkpoint."""
ckpt_epoch, ckpt_dict = self.resume()
if ckpt_epoch != -1:
self.optimizer.load_state_dict(ckpt_dict['optimizer'])
self.lr_scheduler.load_state_dict(ckpt_dict['lr_scheduler'])
if self.optimizer is not None:
self.optimizer.load_state_dict(ckpt_dict['optimizer'])
if self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(ckpt_dict['lr_scheduler'])
return ckpt_epoch + 1
else:
return 0
@@ -261,7 +264,7 @@ class DartsTrainer(Trainer):


class OneShotTrainer(Trainer):
def __init__(self, supernet, criterion, optimizer, lr_scheduler, train_loader, test_loader, train_sampler, evaluate_sampler):
def __init__(self, supernet, criterion, optimizer, lr_scheduler, train_loader, test_loader, sample_type='epoch'):
super().__init__(
model=supernet,
criterion=criterion,
@@ -269,11 +272,15 @@ class OneShotTrainer(Trainer):
lr_scheduler=lr_scheduler,
train_loader=train_loader,
test_loader=test_loader)
self.train_sampler = train_sampler
self.evaluate_sampler = evaluate_sampler
self.iter_sampler = None
self.sample_type = sample_type
assert self.sample_type in ['epoch', 'iter']
self.evaluate_meter = meter.TestMeter(len(self.test_loader))
def register_iter_sample(self, sampler):
self.iter_sampler = sampler

def train_epoch(self, cur_epoch):
def train_epoch(self, cur_epoch, sample=None):
"""Sample path from supernet and train it."""
self.model.train()
lr = self.lr_scheduler.get_last_lr()[0]
@@ -283,8 +290,9 @@ class OneShotTrainer(Trainer):
for cur_iter, (inputs, labels) in enumerate(self.train_loader):
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
# sample subnet
choice = self.train_sampler.suggest()
preds = self.model(inputs, choice)
if self.sample_type == 'iter':
sample = self.iter_sampler.suggest()
preds = self.model(inputs, sample)
loss = self.criterion(preds, labels)
self.optimizer.zero_grad()
loss.backward()
@@ -294,7 +302,8 @@ class OneShotTrainer(Trainer):
# Compute the errors
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
self.train_sampler.record(choice, top1_err) # use top1_err as evaluation
if self.sample_type == 'iter':
self.iter_sampler.record(sample, top1_err) # use top1_err as evaluation
self.train_meter.iter_toc()
# Update and log stats
self.train_meter.update_stats(top1_err, top5_err, loss, lr, inputs.size(0))
@@ -311,19 +320,22 @@ class OneShotTrainer(Trainer):
# Saving checkpoint
if (cur_epoch + 1) % cfg.SAVE_PERIOD == 0:
self.saving(cur_epoch)
return top1_err

@torch.no_grad()
def test_epoch(self, cur_epoch):
def test_epoch(self, cur_epoch, sample=None):
self.model.eval()
self.test_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(self.test_loader):
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
# sample subnet
choice = self.train_sampler.suggest()
preds = self.model(inputs, choice)
if self.sample_type == 'iter':
sample = self.iter_sampler.suggest()
preds = self.model(inputs, sample)
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item()
self.train_sampler.record(choice, top1_err)
if self.sample_type == 'iter':
self.iter_sampler.record(sample, top1_err) # use top1_err as evaluation
self.test_meter.iter_toc()
self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
self.test_meter.log_iter_stats(cur_epoch, cur_iter)
@@ -338,20 +350,20 @@ class OneShotTrainer(Trainer):
if self.best_err > top1_err:
self.best_err = top1_err
self.saving(cur_epoch, best=True)
return top1_err
@torch.no_grad()
def best_arch(self, cycles):
"""Return final best subnet architecture and its performance"""
def evaluate_epoch(self, sample):
"""Return performance of the given sample (subnet)"""
self.model.eval()
for c in range(cycles):
choice = self.evaluate_sampler.suggest()
for cur_iter, (inputs, labels) in enumerate(self.test_loader):
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
preds = self.model(inputs, choice)
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item()
self.evaluate_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
top1_err = self.evaluate_meter.mb_top1_err.get_win_median()
self.evaluate_sampler.record(choice, top1_err)
self.evaluate_meter.reset()
return self.evaluate_sampler.final_best()
# choice = self.evaluate_sampler.suggest()
for cur_iter, (inputs, labels) in enumerate(self.test_loader):
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
preds = self.model(inputs, sample)
top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5])
top1_err, top5_err = top1_err.item(), top5_err.item()
self.evaluate_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS)
top1_err = self.evaluate_meter.mb_top1_err.get_win_median()
# self.evaluate_sampler.record(choice, top1_err)
self.evaluate_meter.reset()
return top1_err
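The `sample_type` switch above corresponds to the two callers in this PR: `SPOS.py` registers a sampler that proposes a new subnet every iteration, while `SNG/search.py` passes one fixed sample for the whole epoch. A sketch of both call patterns (the surrounding objects are assumed to be built as in those scripts):

```
from xnas.runner.trainer import OneShotTrainer

def spos_style(supernet, criterion, optimizer, lr_scheduler,
               train_loader, valid_loader, train_sampler, cur_epoch):
    # a new subnet is drawn from the registered sampler at every iteration
    trainer = OneShotTrainer(supernet, criterion, optimizer, lr_scheduler,
                             train_loader, valid_loader, sample_type='iter')
    trainer.register_iter_sample(train_sampler)   # e.g. RAND(...)
    return trainer.train_epoch(cur_epoch)         # sampler.suggest()/record() run inside

def sng_style(supernet, criterion, optimizer, lr_scheduler,
              train_loader, valid_loader, sample, cur_epoch):
    # one fixed (one-hot) sample is trained for the whole epoch
    trainer = OneShotTrainer(supernet, criterion, optimizer, lr_scheduler,
                             train_loader, valid_loader, sample_type='epoch')
    return trainer.train_epoch(cur_epoch, sample)
```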

+3 -0 xnas/spaces/DARTS/cnn.py

@@ -100,6 +100,9 @@ class DartsCNN(nn.Module):
self.norm_node_index = self._node_index(n_nodes, input_nodes=2, start_index=0)
self.reduce_node_index = self._node_index(n_nodes, input_nodes=2, start_index=self.num_edges)

def weights(self):
return self.parameters()
def forward(self, x, sample):
s0 = s1 = self.stem(x)



+395 -0 xnas/spaces/OFA/MobileNetV3/cnn.py

@@ -0,0 +1,395 @@
import torch
import torch.nn as nn
from copy import deepcopy

from xnas.spaces.OFA.ops import *
from xnas.spaces.OFA.utils import min_divisible_value


__all__ = ["WSConv_Network", "MobileNetV3", "MobileNetV3Large"]


class WSConv_Network(nn.Module):
"""Network with all Conv2d replaced by Weight Standard Conv2d."""

def set_bn_param(self, momentum, eps, gn_channel_per_group=None, ws_eps=None, **kwargs):
"""Replace BN with GN"""
if gn_channel_per_group is None:
return

for m in self.modules():
to_replace_dict = {}
for name, sub_m in m.named_children():
if isinstance(sub_m, nn.BatchNorm2d):
num_groups = sub_m.num_features // min_divisible_value(
sub_m.num_features, gn_channel_per_group
)
gn_m = nn.GroupNorm(
num_groups=num_groups,
num_channels=sub_m.num_features,
eps=sub_m.eps,
affine=True,
)

# load weight
gn_m.weight.data.copy_(sub_m.weight.data)
gn_m.bias.data.copy_(sub_m.bias.data)
# load requires_grad
gn_m.weight.requires_grad = sub_m.weight.requires_grad
gn_m.bias.requires_grad = sub_m.bias.requires_grad

to_replace_dict[name] = gn_m
m._modules.update(to_replace_dict)

"""Init Norm params"""
for m in self.modules():
if type(m) in [nn.BatchNorm1d, nn.BatchNorm2d]:
m.momentum = momentum
m.eps = eps
elif isinstance(m, nn.GroupNorm):
m.eps = eps

"""Replace Conv2d with WeightStandardConv2d"""
if ws_eps is None:
return

for m in self.modules():
to_update_dict = {}
for name, sub_module in m.named_children():
if isinstance(sub_module, nn.Conv2d) and not sub_module.bias:
# only replace conv2d layers that are followed by normalization layers (i.e., no bias)
to_update_dict[name] = sub_module
for name, sub_module in to_update_dict.items():
m._modules[name] = WeightStandardConv2d(
sub_module.in_channels,
sub_module.out_channels,
sub_module.kernel_size,
sub_module.stride,
sub_module.padding,
sub_module.dilation,
sub_module.groups,
sub_module.bias,
)
# load weight
m._modules[name].load_state_dict(sub_module.state_dict())
# load requires_grad
m._modules[name].weight.requires_grad = sub_module.weight.requires_grad
if sub_module.bias is not None:
m._modules[name].bias.requires_grad = sub_module.bias.requires_grad
# set ws_eps
for m in self.modules():
if isinstance(m, WeightStandardConv2d):
m.WS_EPS = ws_eps

def get_bn_param(self):
ws_eps = None
for m in self.modules():
if isinstance(m, WeightStandardConv2d):
ws_eps = m.WS_EPS
break
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
return {
"momentum": m.momentum,
"eps": m.eps,
"ws_eps": ws_eps,
}
elif isinstance(m, nn.GroupNorm):
return {
"momentum": None,
"eps": m.eps,
"gn_channel_per_group": m.num_channels // m.num_groups,
"ws_eps": ws_eps,
}
return None

def get_parameters(self, keys=None, mode="include"):
if keys is None:
for name, param in self.named_parameters():
if param.requires_grad:
yield param
elif mode == "include":
for name, param in self.named_parameters():
flag = False
for key in keys:
if key in name:
flag = True
break
if flag and param.requires_grad:
yield param
elif mode == "exclude":
for name, param in self.named_parameters():
flag = True
for key in keys:
if key in name:
flag = False
break
if flag and param.requires_grad:
yield param
else:
raise ValueError("do not support: %s" % mode)


class MobileNetV3(WSConv_Network):
def __init__(
self, first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
):
super(MobileNetV3, self).__init__()

self.first_conv = first_conv
self.blocks = nn.ModuleList(blocks)
self.final_expand_layer = final_expand_layer
self.global_avg_pool = GlobalAvgPool2d(keep_dim=True)
self.feature_mix_layer = feature_mix_layer
self.classifier = classifier

def forward(self, x):
x = self.first_conv(x)
for block in self.blocks:
x = block(x)
x = self.final_expand_layer(x)
x = self.global_avg_pool(x) # global average pooling
x = self.feature_mix_layer(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x

@property
def module_str(self):
_str = self.first_conv.module_str + "\n"
for block in self.blocks:
_str += block.module_str + "\n"
_str += self.final_expand_layer.module_str + "\n"
_str += self.global_avg_pool.__repr__() + "\n"
_str += self.feature_mix_layer.module_str + "\n"
_str += self.classifier.module_str
return _str

@property
def config(self):
return {
"name": MobileNetV3.__name__,
"bn": self.get_bn_param(),
"first_conv": self.first_conv.config,
"blocks": [block.config for block in self.blocks],
"final_expand_layer": self.final_expand_layer.config,
"feature_mix_layer": self.feature_mix_layer.config,
"classifier": self.classifier.config,
}

@staticmethod
def build_from_config(config):
first_conv = set_layer_from_config(config["first_conv"])
final_expand_layer = set_layer_from_config(config["final_expand_layer"])
feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
classifier = set_layer_from_config(config["classifier"])

blocks = []
for block_config in config["blocks"]:
blocks.append(ResidualBlock.build_from_config(block_config))

net = MobileNetV3(
first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
)
if "bn" in config:
net.set_bn_param(**config["bn"])
else:
net.set_bn_param(momentum=0.1, eps=1e-5)

return net

def zero_last_gamma(self):
for m in self.modules():
if isinstance(m, ResidualBlock):
if isinstance(m.conv, MBConvLayer) and isinstance(
m.shortcut, IdentityLayer
):
m.conv.point_linear.bn.weight.data.zero_()

@property
def grouped_block_index(self):
info_list = []
block_index_list = []
for i, block in enumerate(self.blocks[1:], 1):
if block.shortcut is None and len(block_index_list) > 0:
info_list.append(block_index_list)
block_index_list = []
block_index_list.append(i)
if len(block_index_list) > 0:
info_list.append(block_index_list)
return info_list

@staticmethod
def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
# first conv layer
first_conv = ConvLayer(
3,
input_channel,
kernel_size=3,
stride=2,
use_bn=True,
act_func="h_swish",
ops_order="weight_bn_act",
)
# build mobile blocks
feature_dim = input_channel
blocks = []
for stage_id, block_config_list in cfg.items():
for (
k,
mid_channel,
out_channel,
use_se,
act_func,
stride,
expand_ratio,
) in block_config_list:
mb_conv = MBConvLayer(
feature_dim,
out_channel,
k,
stride,
expand_ratio,
mid_channel,
act_func,
use_se,
)
if stride == 1 and out_channel == feature_dim:
shortcut = IdentityLayer(out_channel, out_channel)
else:
shortcut = None
blocks.append(ResidualBlock(mb_conv, shortcut))
feature_dim = out_channel
# final expand layer
final_expand_layer = ConvLayer(
feature_dim,
feature_dim * 6,
kernel_size=1,
use_bn=True,
act_func="h_swish",
ops_order="weight_bn_act",
)
# feature mix layer
feature_mix_layer = ConvLayer(
feature_dim * 6,
last_channel,
kernel_size=1,
bias=False,
use_bn=False,
act_func="h_swish",
)
# classifier
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier

@staticmethod
def adjust_cfg(
cfg, ks=None, expand_ratio=None, depth_param=None, stage_width_list=None
):
for i, (stage_id, block_config_list) in enumerate(cfg.items()):
for block_config in block_config_list:
if ks is not None and stage_id != "0":
block_config[0] = ks
if expand_ratio is not None and stage_id != "0":
block_config[-1] = expand_ratio
block_config[1] = None
if stage_width_list is not None:
block_config[2] = stage_width_list[i]
if depth_param is not None and stage_id != "0":
new_block_config_list = [block_config_list[0]]
new_block_config_list += [
deepcopy(block_config_list[-1]) for _ in range(depth_param - 1)
]
cfg[stage_id] = new_block_config_list
return cfg

def load_state_dict(self, state_dict, **kwargs):
current_state_dict = self.state_dict()

for key in state_dict:
if key not in current_state_dict:
assert ".mobile_inverted_conv." in key
new_key = key.replace(".mobile_inverted_conv.", ".conv.")
else:
new_key = key
current_state_dict[new_key] = state_dict[key]
super(MobileNetV3, self).load_state_dict(current_state_dict)


class MobileNetV3Large(MobileNetV3):
def __init__(
self,
n_classes=1000,
width_mult=1.0,
bn_param=(0.1, 1e-5),
dropout_rate=0.2,
ks=None,
expand_ratio=None,
depth_param=None,
stage_width_list=None,
):
input_channel = 16
last_channel = 1280

input_channel = make_divisible(input_channel * width_mult)
last_channel = (
make_divisible(last_channel * width_mult)
if width_mult > 1.0
else last_channel
)

cfg = {
# k, exp, c, se, nl, s, e,
"0": [
[3, 16, 16, False, "relu", 1, 1],
],
"1": [
[3, 64, 24, False, "relu", 2, None], # 4
[3, 72, 24, False, "relu", 1, None], # 3
],
"2": [
[5, 72, 40, True, "relu", 2, None], # 3
[5, 120, 40, True, "relu", 1, None], # 3
[5, 120, 40, True, "relu", 1, None], # 3
],
"3": [
[3, 240, 80, False, "h_swish", 2, None], # 6
[3, 200, 80, False, "h_swish", 1, None], # 2.5
[3, 184, 80, False, "h_swish", 1, None], # 2.3
[3, 184, 80, False, "h_swish", 1, None], # 2.3
],
"4": [
[3, 480, 112, True, "h_swish", 1, None], # 6
[3, 672, 112, True, "h_swish", 1, None], # 6
],
"5": [
[5, 672, 160, True, "h_swish", 2, None], # 6
[5, 960, 160, True, "h_swish", 1, None], # 6
[5, 960, 160, True, "h_swish", 1, None], # 6
],
}

cfg = self.adjust_cfg(cfg, ks, expand_ratio, depth_param, stage_width_list)
# width multiplier on mobile setting, change `exp: 1` and `c: 2`
for stage_id, block_config_list in cfg.items():
for block_config in block_config_list:
if block_config[1] is not None:
block_config[1] = make_divisible(block_config[1] * width_mult)
block_config[2] = make_divisible(block_config[2] * width_mult)

(
first_conv,
blocks,
final_expand_layer,
feature_mix_layer,
classifier,
) = self.build_net_via_cfg(
cfg, input_channel, last_channel, n_classes, dropout_rate
)
super(MobileNetV3Large, self).__init__(
first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
)
# set bn param
self.set_bn_param(*bn_param)

+415 -0 xnas/spaces/OFA/MobileNetV3/ofa_cnn.py

@@ -0,0 +1,415 @@
import random
from copy import deepcopy

from xnas.spaces.OFA.dynamic_ops import DynamicMBConvLayer
from xnas.spaces.OFA.utils import val2list, make_divisible
from xnas.spaces.OFA.ops import (
ConvLayer,
IdentityLayer,
LinearLayer,
MBConvLayer,
ResidualBlock,
)
from xnas.spaces.OFA.MobileNetV3.cnn import MobileNetV3


__all__ = ["_OFAMobileNetV3", "OFAMobileNetV3"]


class OFAMobileNetV3(MobileNetV3):
def __init__(
self,
n_classes=1000,
bn_param=(0.1, 1e-5),
dropout_rate=0.1,
base_stage_width=None,
width_mult=1.0,
ks_list=3,
expand_ratio_list=6,
depth_list=4,
):

self.width_mult = width_mult
self.ks_list = val2list(ks_list, 1)
self.expand_ratio_list = val2list(expand_ratio_list, 1)
self.depth_list = val2list(depth_list, 1)

self.ks_list.sort()
self.expand_ratio_list.sort()
self.depth_list.sort()

base_stage_width = [16, 16, 24, 40, 80, 112, 160, 960, 1280]

final_expand_width = make_divisible(base_stage_width[-2] * self.width_mult)
last_channel = make_divisible(base_stage_width[-1] * self.width_mult)

stride_stages = [1, 2, 2, 2, 1, 2]
act_stages = ["relu", "relu", "relu", "h_swish", "h_swish", "h_swish"]
se_stages = [False, False, True, False, True, True]
n_block_list = [1] + [max(self.depth_list)] * 5
width_list = []
for base_width in base_stage_width[:-2]:
width = make_divisible(base_width * self.width_mult)
width_list.append(width)

input_channel, first_block_dim = width_list[0], width_list[1]
# first conv layer
first_conv = ConvLayer(
3, input_channel, kernel_size=3, stride=2, act_func="h_swish"
)
first_block_conv = MBConvLayer(
in_channels=input_channel,
out_channels=first_block_dim,
kernel_size=3,
stride=stride_stages[0],
expand_ratio=1,
act_func=act_stages[0],
use_se=se_stages[0],
)
first_block = ResidualBlock(
first_block_conv,
IdentityLayer(first_block_dim, first_block_dim)
if input_channel == first_block_dim
else None,
)

# inverted residual blocks
self.block_group_info = []
blocks = [first_block]
_block_index = 1
feature_dim = first_block_dim

for width, n_block, s, act_func, use_se in zip(
width_list[2:],
n_block_list[1:],
stride_stages[1:],
act_stages[1:],
se_stages[1:],
):
self.block_group_info.append([_block_index + i for i in range(n_block)])
_block_index += n_block

output_channel = width
for i in range(n_block):
if i == 0:
stride = s
else:
stride = 1
mobile_inverted_conv = DynamicMBConvLayer(
in_channel_list=val2list(feature_dim),
out_channel_list=val2list(output_channel),
kernel_size_list=ks_list,
expand_ratio_list=expand_ratio_list,
stride=stride,
act_func=act_func,
use_se=use_se,
)
if stride == 1 and feature_dim == output_channel:
shortcut = IdentityLayer(feature_dim, feature_dim)
else:
shortcut = None
blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
feature_dim = output_channel
# final expand layer, feature mix layer & classifier
final_expand_layer = ConvLayer(
feature_dim, final_expand_width, kernel_size=1, act_func="h_swish"
)
feature_mix_layer = ConvLayer(
final_expand_width,
last_channel,
kernel_size=1,
bias=False,
use_bn=False,
act_func="h_swish",
)

classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

super(OFAMobileNetV3, self).__init__(
first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
)

# set bn param
self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

# runtime_depth
self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

""" WSConv_Network required methods """

@staticmethod
def name():
return "OFAMobileNetV3"

def forward(self, x):
# first conv
x = self.first_conv(x)
# first block
x = self.blocks[0](x)
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
x = self.blocks[idx](x)
x = self.final_expand_layer(x)
x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
x = self.feature_mix_layer(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x

@property
def module_str(self):
_str = self.first_conv.module_str + "\n"
_str += self.blocks[0].module_str + "\n"

for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
_str += self.blocks[idx].module_str + "\n"

_str += self.final_expand_layer.module_str + "\n"
_str += self.feature_mix_layer.module_str + "\n"
_str += self.classifier.module_str + "\n"
return _str

@property
def config(self):
return {
"name": OFAMobileNetV3.__name__,
"bn": self.get_bn_param(),
"first_conv": self.first_conv.config,
"blocks": [block.config for block in self.blocks],
"final_expand_layer": self.final_expand_layer.config,
"feature_mix_layer": self.feature_mix_layer.config,
"classifier": self.classifier.config,
}

@staticmethod
def build_from_config(config):
raise ValueError("do not support this function")

@property
def grouped_block_index(self):
return self.block_group_info

def load_state_dict(self, state_dict, **kwargs):
model_dict = self.state_dict()
for key in state_dict:
if ".mobile_inverted_conv." in key:
new_key = key.replace(".mobile_inverted_conv.", ".conv.")
else:
new_key = key
if new_key in model_dict:
pass
elif ".bn.bn." in new_key:
new_key = new_key.replace(".bn.bn.", ".bn.")
elif ".conv.conv.weight" in new_key:
new_key = new_key.replace(".conv.conv.weight", ".conv.weight")
elif ".linear.linear." in new_key:
new_key = new_key.replace(".linear.linear.", ".linear.")
##############################################################################
elif ".linear." in new_key:
new_key = new_key.replace(".linear.", ".linear.linear.")
elif "bn." in new_key:
new_key = new_key.replace("bn.", "bn.bn.")
elif "conv.weight" in new_key:
new_key = new_key.replace("conv.weight", "conv.conv.weight")
else:
raise ValueError(new_key)
assert new_key in model_dict, "%s" % new_key
model_dict[new_key] = state_dict[key]
super(OFAMobileNetV3, self).load_state_dict(model_dict)

""" set, sample and get active sub-networks """

def set_max_net(self):
self.set_active_subnet(
ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list)
)

def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
ks = val2list(ks, len(self.blocks) - 1)
expand_ratio = val2list(e, len(self.blocks) - 1)
depth = val2list(d, len(self.block_group_info))

for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
if k is not None:
block.conv.active_kernel_size = k
if e is not None:
block.conv.active_expand_ratio = e

for i, d in enumerate(depth):
if d is not None:
self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

def set_constraint(self, include_list, constraint_type="depth"):
if constraint_type == "depth":
self.__dict__["_depth_include_list"] = include_list.copy()
elif constraint_type == "expand_ratio":
self.__dict__["_expand_include_list"] = include_list.copy()
elif constraint_type == "kernel_size":
self.__dict__["_ks_include_list"] = include_list.copy()
else:
raise NotImplementedError

def clear_constraint(self):
self.__dict__["_depth_include_list"] = None
self.__dict__["_expand_include_list"] = None
self.__dict__["_ks_include_list"] = None

def sample_active_subnet(self):
ks_candidates = (
self.ks_list
if self.__dict__.get("_ks_include_list", None) is None
else self.__dict__["_ks_include_list"]
)
expand_candidates = (
self.expand_ratio_list
if self.__dict__.get("_expand_include_list", None) is None
else self.__dict__["_expand_include_list"]
)
depth_candidates = (
self.depth_list
if self.__dict__.get("_depth_include_list", None) is None
else self.__dict__["_depth_include_list"]
)

# sample kernel size
ks_setting = []
if not isinstance(ks_candidates[0], list):
ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
for k_set in ks_candidates:
k = random.choice(k_set)
ks_setting.append(k)

# sample expand ratio
expand_setting = []
if not isinstance(expand_candidates[0], list):
expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
for e_set in expand_candidates:
e = random.choice(e_set)
expand_setting.append(e)

# sample depth
depth_setting = []
if not isinstance(depth_candidates[0], list):
depth_candidates = [
depth_candidates for _ in range(len(self.block_group_info))
]
for d_set in depth_candidates:
d = random.choice(d_set)
depth_setting.append(d)

self.set_active_subnet(ks_setting, expand_setting, depth_setting)

return {
"ks": ks_setting,
"e": expand_setting,
"d": depth_setting,
}

def get_active_subnet(self, preserve_weight=True):
first_conv = deepcopy(self.first_conv)
blocks = [deepcopy(self.blocks[0])]

final_expand_layer = deepcopy(self.final_expand_layer)
feature_mix_layer = deepcopy(self.feature_mix_layer)
classifier = deepcopy(self.classifier)

input_channel = blocks[0].conv.out_channels
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append(
ResidualBlock(
self.blocks[idx].conv.get_active_subnet(
input_channel, preserve_weight
),
deepcopy(self.blocks[idx].shortcut),
)
)
input_channel = stage_blocks[-1].conv.out_channels
blocks += stage_blocks

_subnet = MobileNetV3(
first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
)
_subnet.set_bn_param(**self.get_bn_param())
return _subnet

def get_active_net_config(self):
# first conv
first_conv_config = self.first_conv.config
first_block_config = self.blocks[0].config
final_expand_config = self.final_expand_layer.config
feature_mix_layer_config = self.feature_mix_layer.config
classifier_config = self.classifier.config

block_config_list = [first_block_config]
input_channel = first_block_config["conv"]["out_channels"]
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append(
{
"name": ResidualBlock.__name__,
"conv": self.blocks[idx].conv.get_active_subnet_config(
input_channel
),
"shortcut": self.blocks[idx].shortcut.config
if self.blocks[idx].shortcut is not None
else None,
}
)
input_channel = self.blocks[idx].conv.active_out_channel
block_config_list += stage_blocks

return {
"name": MobileNetV3.__name__,
"bn": self.get_bn_param(),
"first_conv": first_conv_config,
"blocks": block_config_list,
"final_expand_layer": final_expand_config,
"feature_mix_layer": feature_mix_layer_config,
"classifier": classifier_config,
}

""" Width Related Methods """

def re_organize_middle_weights(self, expand_ratio_stage=0):
for block in self.blocks[1:]:
block.conv.re_organize_middle_weights(expand_ratio_stage)


def _OFAMobileNetV3():
from xnas.core.config import cfg
width_mult_list = cfg.OFA.WIDTH_MULTI_LIST
ks_list = cfg.OFA.KS_LIST
expand_list = cfg.OFA.EXPAND_LIST
depth_list = cfg.OFA.DEPTH_LIST
width_mult_list = (
cfg.OFA.WIDTH_MULTI_LIST[0]
if len(cfg.OFA.WIDTH_MULTI_LIST) == 1
else cfg.OFA.WIDTH_MULTI_LIST
)
return OFAMobileNetV3(
n_classes=cfg.SEARCH.NUM_CLASSES,
bn_param=(0.1, 1e-5),
dropout_rate=0.1,
base_stage_width=None,
width_mult=width_mult_list,
ks_list=ks_list,
expand_ratio_list=expand_list,
depth_list=depth_list,
)
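# Minimal usage sketch (assumes a loaded config providing the OFA.* keys read above):
#   supernet = _OFAMobileNetV3()
#   setting = supernet.sample_active_subnet()                  # {"ks": [...], "e": [...], "d": [...]}
#   subnet = supernet.get_active_subnet(preserve_weight=True)  # standalone MobileNetV3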

+ 242
- 0
xnas/spaces/OFA/ProxylessNet/cnn.py View File

@@ -0,0 +1,242 @@
import json
import torch.nn as nn

from xnas.spaces.OFA.ops import (
set_layer_from_config,
MBConvLayer,
ConvLayer,
IdentityLayer,
LinearLayer,
ResidualBlock,
GlobalAvgPool2d,
)
from xnas.spaces.OFA.utils import val2list, make_divisible
from xnas.spaces.OFA.MobileNetV3.cnn import WSConv_Network


__all__ = ["proxyless_base", "ProxylessNASNet", "MobileNetV2"]


def proxyless_base(
net_config=None,
n_classes=None,
bn_param=None,
dropout_rate=None,
):
assert net_config is not None, "Please input a network config"
with open(net_config, "r") as f:
    net_config_json = json.load(f)

if n_classes is not None:
net_config_json["classifier"]["out_features"] = n_classes
if dropout_rate is not None:
net_config_json["classifier"]["dropout_rate"] = dropout_rate

net = ProxylessNASNet.build_from_config(net_config_json)
if bn_param is not None:
net.set_bn_param(*bn_param)

return net
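# Example (hypothetical config path):
#   proxyless_base(net_config="proxyless_gpu.config", n_classes=1000, bn_param=(0.1, 1e-3), dropout_rate=0.1)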


class ProxylessNASNet(WSConv_Network):
def __init__(self, first_conv, blocks, feature_mix_layer, classifier):
super(ProxylessNASNet, self).__init__()

self.first_conv = first_conv
self.blocks = nn.ModuleList(blocks)
self.feature_mix_layer = feature_mix_layer
self.global_avg_pool = GlobalAvgPool2d(keep_dim=False)
self.classifier = classifier

def forward(self, x):
x = self.first_conv(x)
for block in self.blocks:
x = block(x)
if self.feature_mix_layer is not None:
x = self.feature_mix_layer(x)
x = self.global_avg_pool(x)
x = self.classifier(x)
return x

@property
def module_str(self):
_str = self.first_conv.module_str + "\n"
for block in self.blocks:
_str += block.module_str + "\n"
_str += self.feature_mix_layer.module_str + "\n"
_str += self.global_avg_pool.__repr__() + "\n"
_str += self.classifier.module_str
return _str

@property
def config(self):
return {
"name": ProxylessNASNet.__name__,
"bn": self.get_bn_param(),
"first_conv": self.first_conv.config,
"blocks": [block.config for block in self.blocks],
"feature_mix_layer": None
if self.feature_mix_layer is None
else self.feature_mix_layer.config,
"classifier": self.classifier.config,
}

@staticmethod
def build_from_config(config):
first_conv = set_layer_from_config(config["first_conv"])
feature_mix_layer = set_layer_from_config(config["feature_mix_layer"])
classifier = set_layer_from_config(config["classifier"])

blocks = []
for block_config in config["blocks"]:
blocks.append(ResidualBlock.build_from_config(block_config))

net = ProxylessNASNet(first_conv, blocks, feature_mix_layer, classifier)
if "bn" in config:
net.set_bn_param(**config["bn"])
else:
net.set_bn_param(momentum=0.1, eps=1e-3)

return net

def zero_last_gamma(self):
for m in self.modules():
if isinstance(m, ResidualBlock):
if isinstance(m.conv, MBConvLayer) and isinstance(
m.shortcut, IdentityLayer
):
m.conv.point_linear.bn.weight.data.zero_()

@property
def grouped_block_index(self):
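# A new stage starts at every block without a shortcut (i.e. a down-sampling or
# channel-changing block); consecutive blocks with shortcuts join the same stage.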
info_list = []
block_index_list = []
for i, block in enumerate(self.blocks[1:], 1):
if block.shortcut is None and len(block_index_list) > 0:
info_list.append(block_index_list)
block_index_list = []
block_index_list.append(i)
if len(block_index_list) > 0:
info_list.append(block_index_list)
return info_list

def load_state_dict(self, state_dict, **kwargs):
current_state_dict = self.state_dict()

for key in state_dict:
if key not in current_state_dict:
assert ".mobile_inverted_conv." in key
new_key = key.replace(".mobile_inverted_conv.", ".conv.")
else:
new_key = key
current_state_dict[new_key] = state_dict[key]
super(ProxylessNASNet, self).load_state_dict(current_state_dict)


class MobileNetV2(ProxylessNASNet):
def __init__(
self,
n_classes=1000,
width_mult=1.0,
bn_param=(0.1, 1e-3),
dropout_rate=0.2,
ks=None,
expand_ratio=None,
depth_param=None,
stage_width_list=None,
):

ks = 3 if ks is None else ks
expand_ratio = 6 if expand_ratio is None else expand_ratio

input_channel = 32
last_channel = 1280

input_channel = make_divisible(input_channel * width_mult)
last_channel = (
make_divisible(last_channel * width_mult)
if width_mult > 1.0
else last_channel
)

inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[expand_ratio, 24, 2, 2],
[expand_ratio, 32, 3, 2],
[expand_ratio, 64, 4, 2],
[expand_ratio, 96, 3, 1],
[expand_ratio, 160, 3, 2],
[expand_ratio, 320, 1, 1],
]
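# Each row above is (t: expand ratio, c: output channels, n: number of blocks, s: stride of the first block in the stage).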

if depth_param is not None:
assert isinstance(depth_param, int)
for i in range(1, len(inverted_residual_setting) - 1):
inverted_residual_setting[i][2] = depth_param

if stage_width_list is not None:
for i in range(len(inverted_residual_setting)):
inverted_residual_setting[i][1] = stage_width_list[i]

ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
_pt = 0

# first conv layer
first_conv = ConvLayer(
3,
input_channel,
kernel_size=3,
stride=2,
use_bn=True,
act_func="relu6",
ops_order="weight_bn_act",
)
# inverted residual blocks
blocks = []
for t, c, n, s in inverted_residual_setting:
output_channel = make_divisible(c * width_mult)
for i in range(n):
if i == 0:
stride = s
else:
stride = 1
if t == 1:
kernel_size = 3
else:
kernel_size = ks[_pt]
_pt += 1
mobile_inverted_conv = MBConvLayer(
in_channels=input_channel,
out_channels=output_channel,
kernel_size=kernel_size,
stride=stride,
expand_ratio=t,
)
if stride == 1:
if input_channel == output_channel:
shortcut = IdentityLayer(input_channel, input_channel)
else:
shortcut = None
else:
shortcut = None
blocks.append(ResidualBlock(mobile_inverted_conv, shortcut))
input_channel = output_channel
# 1x1_conv before global average pooling
feature_mix_layer = ConvLayer(
input_channel,
last_channel,
kernel_size=1,
use_bn=True,
act_func="relu6",
ops_order="weight_bn_act",
)

classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

super(MobileNetV2, self).__init__(
first_conv, blocks, feature_mix_layer, classifier
)

# set bn param
self.set_bn_param(*bn_param)

+ 390
- 0
xnas/spaces/OFA/ProxylessNet/ofa_cnn.py View File

@@ -0,0 +1,390 @@
import random
from copy import deepcopy

from xnas.spaces.OFA.dynamic_ops import DynamicMBConvLayer
from xnas.spaces.OFA.utils import val2list, make_divisible
from xnas.spaces.OFA.ops import (
ConvLayer,
IdentityLayer,
LinearLayer,
MBConvLayer,
ResidualBlock,
)
from xnas.spaces.OFA.ProxylessNet.cnn import ProxylessNASNet


__all__ = ["_OFAProxylessNASNet", "OFAProxylessNASNet"]



class OFAProxylessNASNet(ProxylessNASNet):
def __init__(
self,
n_classes=1000,
bn_param=(0.1, 1e-3),
dropout_rate=0.1,
base_stage_width=None,
width_mult=1.0,
ks_list=3,
expand_ratio_list=6,
depth_list=4,
):

self.width_mult = width_mult
self.ks_list = val2list(ks_list, 1)
self.expand_ratio_list = val2list(expand_ratio_list, 1)
self.depth_list = val2list(depth_list, 1)

self.ks_list.sort()
self.expand_ratio_list.sort()
self.depth_list.sort()

if base_stage_width == "google":
# MobileNetV2 Stage Width
base_stage_width = [32, 16, 24, 32, 64, 96, 160, 320, 1280]
else:
# ProxylessNAS Stage Width
base_stage_width = [32, 16, 24, 40, 80, 96, 192, 320, 1280]

input_channel = make_divisible(base_stage_width[0] * self.width_mult)
first_block_width = make_divisible(base_stage_width[1] * self.width_mult)
last_channel = make_divisible(base_stage_width[-1] * self.width_mult)

# first conv layer
first_conv = ConvLayer(
3,
input_channel,
kernel_size=3,
stride=2,
use_bn=True,
act_func="relu6",
ops_order="weight_bn_act",
)
# first block
first_block_conv = MBConvLayer(
in_channels=input_channel,
out_channels=first_block_width,
kernel_size=3,
stride=1,
expand_ratio=1,
act_func="relu6",
)
first_block = ResidualBlock(first_block_conv, None)

input_channel = first_block_width
# inverted residual blocks
self.block_group_info = []
blocks = [first_block]
_block_index = 1

stride_stages = [2, 2, 2, 1, 2, 1]
n_block_list = [max(self.depth_list)] * 5 + [1]

width_list = []
for base_width in base_stage_width[2:-1]:
width = make_divisible(base_width * self.width_mult)
width_list.append(width)

for width, n_block, s in zip(width_list, n_block_list, stride_stages):
self.block_group_info.append([_block_index + i for i in range(n_block)])
_block_index += n_block

output_channel = width
for i in range(n_block):
if i == 0:
stride = s
else:
stride = 1

mobile_inverted_conv = DynamicMBConvLayer(
in_channel_list=val2list(input_channel, 1),
out_channel_list=val2list(output_channel, 1),
kernel_size_list=ks_list,
expand_ratio_list=expand_ratio_list,
stride=stride,
act_func="relu6",
)

if stride == 1 and input_channel == output_channel:
shortcut = IdentityLayer(input_channel, input_channel)
else:
shortcut = None

mb_inverted_block = ResidualBlock(mobile_inverted_conv, shortcut)

blocks.append(mb_inverted_block)
input_channel = output_channel
# 1x1_conv before global average pooling
feature_mix_layer = ConvLayer(
input_channel,
last_channel,
kernel_size=1,
use_bn=True,
act_func="relu6",
)
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

super(OFAProxylessNASNet, self).__init__(
first_conv, blocks, feature_mix_layer, classifier
)

# set bn param
self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

# runtime_depth[i]: number of active blocks in stage i (initialized to the maximum depth)
self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

""" WSConv_Network required methods """

@staticmethod
def name():
return "OFAProxylessNASNet"

def forward(self, x):
# first conv
x = self.first_conv(x)
# first block
x = self.blocks[0](x)

# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
x = self.blocks[idx](x)

# feature_mix_layer
x = self.feature_mix_layer(x)
x = x.mean(3).mean(2)

x = self.classifier(x)
return x

@property
def module_str(self):
_str = self.first_conv.module_str + "\n"
_str += self.blocks[0].module_str + "\n"

for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
_str += self.blocks[idx].module_str + "\n"
_str += self.feature_mix_layer.module_str + "\n"
_str += self.classifier.module_str + "\n"
return _str

@property
def config(self):
return {
"name": OFAProxylessNASNet.__name__,
"bn": self.get_bn_param(),
"first_conv": self.first_conv.config,
"blocks": [block.config for block in self.blocks],
"feature_mix_layer": None
if self.feature_mix_layer is None
else self.feature_mix_layer.config,
"classifier": self.classifier.config,
}

@staticmethod
def build_from_config(config):
raise ValueError("do not support this function")

@property
def grouped_block_index(self):
return self.block_group_info

def load_state_dict(self, state_dict, **kwargs):
model_dict = self.state_dict()
for key in state_dict:
if ".mobile_inverted_conv." in key:
new_key = key.replace(".mobile_inverted_conv.", ".conv.")
else:
new_key = key
if new_key in model_dict:
pass
elif ".bn.bn." in new_key:
new_key = new_key.replace(".bn.bn.", ".bn.")
elif ".conv.conv.weight" in new_key:
new_key = new_key.replace(".conv.conv.weight", ".conv.weight")
elif ".linear.linear." in new_key:
new_key = new_key.replace(".linear.linear.", ".linear.")
# The branches below handle the reverse mapping: plain static-layer keys are expanded to the nested dynamic-module names (e.g. 'bn.' -> 'bn.bn.').
elif ".linear." in new_key:
new_key = new_key.replace(".linear.", ".linear.linear.")
elif "bn." in new_key:
new_key = new_key.replace("bn.", "bn.bn.")
elif "conv.weight" in new_key:
new_key = new_key.replace("conv.weight", "conv.conv.weight")
else:
raise ValueError(new_key)
assert new_key in model_dict, "%s" % new_key
model_dict[new_key] = state_dict[key]
super(OFAProxylessNASNet, self).load_state_dict(model_dict)

""" set, sample and get active sub-networks """

def set_max_net(self):
self.set_active_subnet(
ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list)
)

def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
ks = val2list(ks, len(self.blocks) - 1)
expand_ratio = val2list(e, len(self.blocks) - 1)
depth = val2list(d, len(self.block_group_info))

for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
if k is not None:
block.conv.active_kernel_size = k
if e is not None:
block.conv.active_expand_ratio = e

for i, d in enumerate(depth):
if d is not None:
self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

def set_constraint(self, include_list, constraint_type="depth"):
if constraint_type == "depth":
self.__dict__["_depth_include_list"] = include_list.copy()
elif constraint_type == "expand_ratio":
self.__dict__["_expand_include_list"] = include_list.copy()
elif constraint_type == "kernel_size":
self.__dict__["_ks_include_list"] = include_list.copy()
else:
raise NotImplementedError

def clear_constraint(self):
self.__dict__["_depth_include_list"] = None
self.__dict__["_expand_include_list"] = None
self.__dict__["_ks_include_list"] = None

def sample_active_subnet(self):
ks_candidates = (
self.ks_list
if self.__dict__.get("_ks_include_list", None) is None
else self.__dict__["_ks_include_list"]
)
expand_candidates = (
self.expand_ratio_list
if self.__dict__.get("_expand_include_list", None) is None
else self.__dict__["_expand_include_list"]
)
depth_candidates = (
self.depth_list
if self.__dict__.get("_depth_include_list", None) is None
else self.__dict__["_depth_include_list"]
)

# sample kernel size
ks_setting = []
if not isinstance(ks_candidates[0], list):
ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
for k_set in ks_candidates:
k = random.choice(k_set)
ks_setting.append(k)

# sample expand ratio
expand_setting = []
if not isinstance(expand_candidates[0], list):
expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
for e_set in expand_candidates:
e = random.choice(e_set)
expand_setting.append(e)

# sample depth
depth_setting = []
if not isinstance(depth_candidates[0], list):
depth_candidates = [
depth_candidates for _ in range(len(self.block_group_info))
]
for d_set in depth_candidates:
d = random.choice(d_set)
depth_setting.append(d)

depth_setting[-1] = 1
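# The final ProxylessNAS stage always contains a single block, so its sampled depth is overridden here.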
self.set_active_subnet(ks_setting, expand_setting, depth_setting)

return {
"ks": ks_setting,
"e": expand_setting,
"d": depth_setting,
}

def get_active_subnet(self, preserve_weight=True):
first_conv = deepcopy(self.first_conv)
blocks = [deepcopy(self.blocks[0])]
feature_mix_layer = deepcopy(self.feature_mix_layer)
classifier = deepcopy(self.classifier)

input_channel = blocks[0].conv.out_channels
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append(
ResidualBlock(
self.blocks[idx].conv.get_active_subnet(
input_channel, preserve_weight
),
deepcopy(self.blocks[idx].shortcut),
)
)
input_channel = stage_blocks[-1].conv.out_channels
blocks += stage_blocks

_subnet = ProxylessNASNet(first_conv, blocks, feature_mix_layer, classifier)
_subnet.set_bn_param(**self.get_bn_param())
return _subnet

def get_active_net_config(self):
first_conv_config = self.first_conv.config
first_block_config = self.blocks[0].config
feature_mix_layer_config = self.feature_mix_layer.config
classifier_config = self.classifier.config

block_config_list = [first_block_config]
input_channel = first_block_config["conv"]["out_channels"]
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append(
{
"name": ResidualBlock.__name__,
"conv": self.blocks[idx].conv.get_active_subnet_config(
input_channel
),
"shortcut": self.blocks[idx].shortcut.config
if self.blocks[idx].shortcut is not None
else None,
}
)
try:
input_channel = self.blocks[idx].conv.active_out_channel
except Exception:
input_channel = self.blocks[idx].conv.out_channels
block_config_list += stage_blocks

return {
"name": ProxylessNASNet.__name__,
"bn": self.get_bn_param(),
"first_conv": first_conv_config,
"blocks": block_config_list,
"feature_mix_layer": feature_mix_layer_config,
"classifier": classifier_config,
}

""" Width Related Methods """

def re_organize_middle_weights(self, expand_ratio_stage=0):
for block in self.blocks[1:]:
block.conv.re_organize_middle_weights(expand_ratio_stage)


def _OFAProxylessNASNet():
return OFAProxylessNASNet()
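# Constrained sampling sketch (illustrative values): include-lists restrict the pools
# used by sample_active_subnet(), e.g.
#   net = _OFAProxylessNASNet()
#   net.set_constraint([3, 5], constraint_type="kernel_size")
#   setting = net.sample_active_subnet()   # kernel sizes drawn only from {3, 5}
#   net.clear_constraint()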

+ 248
- 0
xnas/spaces/OFA/ResNets/cnn.py View File

@@ -0,0 +1,248 @@
import torch.nn as nn

from xnas.spaces.OFA.utils import make_divisible
from xnas.spaces.OFA.ops import (
set_layer_from_config,
ConvLayer,
IdentityLayer,
LinearLayer,
ResidualBlock,
ResNetBottleneckBlock,
GlobalAvgPool2d
)
from xnas.spaces.OFA.MobileNetV3.cnn import WSConv_Network


__all__ = ["ResNet", "ResNet50", "ResNet50D"]


class ResNet(WSConv_Network):

BASE_DEPTH_LIST = [2, 2, 4, 2]
STAGE_WIDTH_LIST = [256, 512, 1024, 2048]
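# BASE_DEPTH_LIST: base number of blocks per stage (elastic depth is added on top of these);
# STAGE_WIDTH_LIST: output width of each of the four stages.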

def __init__(self, input_stem, blocks, classifier):
super(ResNet, self).__init__()

self.input_stem = nn.ModuleList(input_stem)
self.max_pooling = nn.MaxPool2d(
kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
)
self.blocks = nn.ModuleList(blocks)
self.global_avg_pool = GlobalAvgPool2d(keep_dim=False)
self.classifier = classifier

def forward(self, x):
for layer in self.input_stem:
x = layer(x)
x = self.max_pooling(x)
for block in self.blocks:
x = block(x)
x = self.global_avg_pool(x)
x = self.classifier(x)
return x

@property
def module_str(self):
_str = ""
for layer in self.input_stem:
_str += layer.module_str + "\n"
_str += "max_pooling(ks=3, stride=2)\n"
for block in self.blocks:
_str += block.module_str + "\n"
_str += self.global_avg_pool.__repr__() + "\n"
_str += self.classifier.module_str
return _str

@property
def config(self):
return {
"name": ResNet.__name__,
"bn": self.get_bn_param(),
"input_stem": [layer.config for layer in self.input_stem],
"blocks": [block.config for block in self.blocks],
"classifier": self.classifier.config,
}

@staticmethod
def build_from_config(config):
classifier = set_layer_from_config(config["classifier"])

input_stem = []
for layer_config in config["input_stem"]:
input_stem.append(set_layer_from_config(layer_config))
blocks = []
for block_config in config["blocks"]:
blocks.append(set_layer_from_config(block_config))

net = ResNet(input_stem, blocks, classifier)
if "bn" in config:
net.set_bn_param(**config["bn"])
else:
net.set_bn_param(momentum=0.1, eps=1e-5)

return net

def zero_last_gamma(self):
for m in self.modules():
if isinstance(m, ResNetBottleneckBlock) and isinstance(
m.downsample, IdentityLayer
):
m.conv3.bn.weight.data.zero_()

@property
def grouped_block_index(self):
info_list = []
block_index_list = []
for i, block in enumerate(self.blocks):
if (
not isinstance(block.downsample, IdentityLayer)
and len(block_index_list) > 0
):
info_list.append(block_index_list)
block_index_list = []
block_index_list.append(i)
if len(block_index_list) > 0:
info_list.append(block_index_list)
return info_list

def load_state_dict(self, state_dict, **kwargs):
super(ResNet, self).load_state_dict(state_dict)


class ResNet50(ResNet):
def __init__(
self,
n_classes=1000,
width_mult=1.0,
bn_param=(0.1, 1e-5),
dropout_rate=0,
expand_ratio=None,
depth_param=None,
):

expand_ratio = 0.25 if expand_ratio is None else expand_ratio

input_channel = make_divisible(64 * width_mult)
stage_width_list = ResNet.STAGE_WIDTH_LIST.copy()
for i, width in enumerate(stage_width_list):
stage_width_list[i] = make_divisible(width * width_mult)

depth_list = [3, 4, 6, 3]
if depth_param is not None:
for i, depth in enumerate(ResNet.BASE_DEPTH_LIST):
depth_list[i] = depth + depth_param

stride_list = [1, 2, 2, 2]

# build input stem
input_stem = [
ConvLayer(
3,
input_channel,
kernel_size=7,
stride=2,
use_bn=True,
act_func="relu",
ops_order="weight_bn_act",
)
]

# blocks
blocks = []
for d, width, s in zip(depth_list, stage_width_list, stride_list):
for i in range(d):
stride = s if i == 0 else 1
bottleneck_block = ResNetBottleneckBlock(
input_channel,
width,
kernel_size=3,
stride=stride,
expand_ratio=expand_ratio,
act_func="relu",
downsample_mode="conv",
)
blocks.append(bottleneck_block)
input_channel = width
# classifier
classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)

super(ResNet50, self).__init__(input_stem, blocks, classifier)

# set bn param
self.set_bn_param(*bn_param)


class ResNet50D(ResNet):
def __init__(
self,
n_classes=1000,
width_mult=1.0,
bn_param=(0.1, 1e-5),
dropout_rate=0,
expand_ratio=None,
depth_param=None,
):

expand_ratio = 0.25 if expand_ratio is None else expand_ratio

input_channel = make_divisible(64 * width_mult)
mid_input_channel = make_divisible(input_channel // 2)
stage_width_list = ResNet.STAGE_WIDTH_LIST.copy()
for i, width in enumerate(stage_width_list):
stage_width_list[i] = make_divisible(width * width_mult)

depth_list = [3, 4, 6, 3]
if depth_param is not None:
for i, depth in enumerate(ResNet.BASE_DEPTH_LIST):
depth_list[i] = depth + depth_param

stride_list = [1, 2, 2, 2]

# build input stem
input_stem = [
ConvLayer(3, mid_input_channel, 3, stride=2, use_bn=True, act_func="relu"),
ResidualBlock(
ConvLayer(
mid_input_channel,
mid_input_channel,
3,
stride=1,
use_bn=True,
act_func="relu",
),
IdentityLayer(mid_input_channel, mid_input_channel),
),
ConvLayer(
mid_input_channel,
input_channel,
3,
stride=1,
use_bn=True,
act_func="relu",
),
]

# blocks
blocks = []
for d, width, s in zip(depth_list, stage_width_list, stride_list):
for i in range(d):
stride = s if i == 0 else 1
bottleneck_block = ResNetBottleneckBlock(
input_channel,
width,
kernel_size=3,
stride=stride,
expand_ratio=expand_ratio,
act_func="relu",
downsample_mode="avgpool_conv",
)
blocks.append(bottleneck_block)
input_channel = width
# classifier
classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)

super(ResNet50D, self).__init__(input_stem, blocks, classifier)

# set bn param
self.set_bn_param(*bn_param)

+ 353
- 0
xnas/spaces/OFA/ResNets/ofa_cnn.py View File

@@ -0,0 +1,353 @@
import random

from xnas.spaces.OFA.utils import val2list, make_divisible
from xnas.spaces.OFA.dynamic_ops import (
DynamicConvLayer,
DynamicLinearLayer,
DynamicResNetBottleneckBlock,
)
from xnas.spaces.OFA.ops import (
IdentityLayer,
ResidualBlock,
)
from xnas.spaces.OFA.ResNets.cnn import ResNet


__all__ = ["_OFAResNet", "OFAResNet"]


class OFAResNet(ResNet):
def __init__(
self,
n_classes=1000,
bn_param=(0.1, 1e-5),
dropout_rate=0,
depth_list=2,
expand_ratio_list=0.25,
width_mult_list=1.0,
):

self.depth_list = val2list(depth_list)
self.expand_ratio_list = val2list(expand_ratio_list)
self.width_mult_list = val2list(width_mult_list)
# sort
self.depth_list.sort()
self.expand_ratio_list.sort()
self.width_mult_list.sort()

input_channel = [
make_divisible(64 * width_mult)
for width_mult in self.width_mult_list
]
mid_input_channel = [
make_divisible(channel // 2)
for channel in input_channel
]

stage_width_list = ResNet.STAGE_WIDTH_LIST.copy()
for i, width in enumerate(stage_width_list):
stage_width_list[i] = [
make_divisible(width * width_mult)
for width_mult in self.width_mult_list
]

n_block_list = [
base_depth + max(self.depth_list) for base_depth in ResNet.BASE_DEPTH_LIST
]
stride_list = [1, 2, 2, 2]

# build input stem
input_stem = [
DynamicConvLayer(
val2list(3),
mid_input_channel,
3,
stride=2,
use_bn=True,
act_func="relu",
),
ResidualBlock(
DynamicConvLayer(
mid_input_channel,
mid_input_channel,
3,
stride=1,
use_bn=True,
act_func="relu",
),
IdentityLayer(mid_input_channel, mid_input_channel),
),
DynamicConvLayer(
mid_input_channel,
input_channel,
3,
stride=1,
use_bn=True,
act_func="relu",
),
]

# blocks
blocks = []
for d, width, s in zip(n_block_list, stage_width_list, stride_list):
for i in range(d):
stride = s if i == 0 else 1
bottleneck_block = DynamicResNetBottleneckBlock(
input_channel,
width,
expand_ratio_list=self.expand_ratio_list,
kernel_size=3,
stride=stride,
act_func="relu",
downsample_mode="avgpool_conv",
)
blocks.append(bottleneck_block)
input_channel = width
# classifier
classifier = DynamicLinearLayer(
input_channel, n_classes, dropout_rate=dropout_rate
)

super(OFAResNet, self).__init__(input_stem, blocks, classifier)

# set bn param
self.set_bn_param(*bn_param)

# runtime_depth
self.input_stem_skipping = 0
self.runtime_depth = [0] * len(n_block_list)
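# For the ResNet space, runtime_depth[i] counts how many trailing blocks of stage i are skipped (0 means full depth).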

@property
def ks_list(self):
return [3]

@staticmethod
def name():
return "OFAResNet"

def forward(self, x):
for layer in self.input_stem:
if (
self.input_stem_skipping > 0
and isinstance(layer, ResidualBlock)
and isinstance(layer.shortcut, IdentityLayer)
):
pass
else:
x = layer(x)
x = self.max_pooling(x)
for stage_id, block_idx in enumerate(self.grouped_block_index):
depth_param = self.runtime_depth[stage_id]
active_idx = block_idx[: len(block_idx) - depth_param]
for idx in active_idx:
x = self.blocks[idx](x)
x = self.global_avg_pool(x)
x = self.classifier(x)
return x

@property
def module_str(self):
_str = ""
for layer in self.input_stem:
if (
self.input_stem_skipping > 0
and isinstance(layer, ResidualBlock)
and isinstance(layer.shortcut, IdentityLayer)
):
pass
else:
_str += layer.module_str + "\n"
_str += "max_pooling(ks=3, stride=2)\n"
for stage_id, block_idx in enumerate(self.grouped_block_index):
depth_param = self.runtime_depth[stage_id]
active_idx = block_idx[: len(block_idx) - depth_param]
for idx in active_idx:
_str += self.blocks[idx].module_str + "\n"
_str += self.global_avg_pool.__repr__() + "\n"
_str += self.classifier.module_str
return _str

@property
def config(self):
return {
"name": OFAResNet.__name__,
"bn": self.get_bn_param(),
"input_stem": [layer.config for layer in self.input_stem],
"blocks": [block.config for block in self.blocks],
"classifier": self.classifier.config,
}

@staticmethod
def build_from_config(config):
raise ValueError("do not support this function")

def load_state_dict(self, state_dict, **kwargs):
model_dict = self.state_dict()
for key in state_dict:
new_key = key
if new_key in model_dict:
pass
elif ".linear." in new_key:
new_key = new_key.replace(".linear.", ".linear.linear.")
elif "bn." in new_key:
new_key = new_key.replace("bn.", "bn.bn.")
elif "conv.weight" in new_key:
new_key = new_key.replace("conv.weight", "conv.conv.weight")
else:
raise ValueError(new_key)
assert new_key in model_dict, "%s" % new_key
model_dict[new_key] = state_dict[key]
super(OFAResNet, self).load_state_dict(model_dict)

""" set, sample and get active sub-networks """

def set_max_net(self):
self.set_active_subnet(
d=max(self.depth_list),
e=max(self.expand_ratio_list),
w=len(self.width_mult_list) - 1,
)

def set_active_subnet(self, d=None, e=None, w=None, **kwargs):
depth = val2list(d, len(ResNet.BASE_DEPTH_LIST) + 1)
expand_ratio = val2list(e, len(self.blocks))
width_mult = val2list(w, len(ResNet.BASE_DEPTH_LIST) + 2)

for block, e in zip(self.blocks, expand_ratio):
if e is not None:
block.active_expand_ratio = e

if width_mult[0] is not None:
self.input_stem[1].conv.active_out_channel = self.input_stem[
0
].active_out_channel = self.input_stem[0].out_channel_list[width_mult[0]]
if width_mult[1] is not None:
self.input_stem[2].active_out_channel = self.input_stem[2].out_channel_list[
width_mult[1]
]

if depth[0] is not None:
self.input_stem_skipping = depth[0] != max(self.depth_list)
for stage_id, (block_idx, d, w) in enumerate(
zip(self.grouped_block_index, depth[1:], width_mult[2:])
):
if d is not None:
self.runtime_depth[stage_id] = max(self.depth_list) - d
if w is not None:
for idx in block_idx:
self.blocks[idx].active_out_channel = self.blocks[
idx
].out_channel_list[w]

def sample_active_subnet(self):
# sample expand ratio
expand_setting = []
for block in self.blocks:
expand_setting.append(random.choice(block.expand_ratio_list))

# sample depth
depth_setting = [random.choice([max(self.depth_list), min(self.depth_list)])]
for stage_id in range(len(ResNet.BASE_DEPTH_LIST)):
depth_setting.append(random.choice(self.depth_list))

# sample width_mult
width_mult_setting = [
random.choice(list(range(len(self.input_stem[0].out_channel_list)))),
random.choice(list(range(len(self.input_stem[2].out_channel_list)))),
]
for stage_id, block_idx in enumerate(self.grouped_block_index):
stage_first_block = self.blocks[block_idx[0]]
width_mult_setting.append(
random.choice(list(range(len(stage_first_block.out_channel_list))))
)

arch_config = {"d": depth_setting, "e": expand_setting, "w": width_mult_setting}
self.set_active_subnet(**arch_config)
return arch_config

def get_active_subnet(self, preserve_weight=True):
input_stem = [self.input_stem[0].get_active_subnet(3, preserve_weight)]
if self.input_stem_skipping <= 0:
input_stem.append(
ResidualBlock(
self.input_stem[1].conv.get_active_subnet(
self.input_stem[0].active_out_channel, preserve_weight
),
IdentityLayer(
self.input_stem[0].active_out_channel,
self.input_stem[0].active_out_channel,
),
)
)
input_stem.append(
self.input_stem[2].get_active_subnet(
self.input_stem[0].active_out_channel, preserve_weight
)
)
input_channel = self.input_stem[2].active_out_channel

blocks = []
for stage_id, block_idx in enumerate(self.grouped_block_index):
depth_param = self.runtime_depth[stage_id]
active_idx = block_idx[: len(block_idx) - depth_param]
for idx in active_idx:
blocks.append(
self.blocks[idx].get_active_subnet(input_channel, preserve_weight)
)
input_channel = self.blocks[idx].active_out_channel
classifier = self.classifier.get_active_subnet(input_channel, preserve_weight)
subnet = ResNet(input_stem, blocks, classifier)

subnet.set_bn_param(**self.get_bn_param())
return subnet

def get_active_net_config(self):
input_stem_config = [self.input_stem[0].get_active_subnet_config(3)]
if self.input_stem_skipping <= 0:
input_stem_config.append(
{
"name": ResidualBlock.__name__,
"conv": self.input_stem[1].conv.get_active_subnet_config(
self.input_stem[0].active_out_channel
),
"shortcut": IdentityLayer(
self.input_stem[0].active_out_channel,
self.input_stem[0].active_out_channel,
),
}
)
input_stem_config.append(
self.input_stem[2].get_active_subnet_config(
self.input_stem[0].active_out_channel
)
)
input_channel = self.input_stem[2].active_out_channel

blocks_config = []
for stage_id, block_idx in enumerate(self.grouped_block_index):
depth_param = self.runtime_depth[stage_id]
active_idx = block_idx[: len(block_idx) - depth_param]
for idx in active_idx:
blocks_config.append(
self.blocks[idx].get_active_subnet_config(input_channel)
)
input_channel = self.blocks[idx].active_out_channel
classifier_config = self.classifier.get_active_subnet_config(input_channel)
return {
"name": ResNet.__name__,
"bn": self.get_bn_param(),
"input_stem": input_stem_config,
"blocks": blocks_config,
"classifier": classifier_config,
}

""" Width Related Methods """

def re_organize_middle_weights(self, expand_ratio_stage=0):
for block in self.blocks:
block.re_organize_middle_weights(expand_ratio_stage)


def _OFAResNet():
return OFAResNet()


+ 1182
- 0
xnas/spaces/OFA/dynamic_ops.py View File

@@ -0,0 +1,1182 @@
from copy import deepcopy
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from xnas.spaces.OFA.utils import (
get_same_padding,
val2list,
make_divisible,
)
from xnas.spaces.OFA.ops import (
build_activation,
set_layer_from_config,
SEModule,
WeightStandardConv2d,
MBConvLayer,
ConvLayer,
IdentityLayer,
ResNetBottleneckBlock,
LinearLayer,
)


class DynamicSeparableConv2d(nn.Module):
KERNEL_TRANSFORM_MODE = 1 # None or 1
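# 1: derive smaller kernels from the centre of the largest kernel through learned linear
# transformation matrices (elastic kernel size); None: plain centre cropping without a transform.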

def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1):
super(DynamicSeparableConv2d, self).__init__()

self.max_in_channels = max_in_channels
self.kernel_size_list = kernel_size_list
self.stride = stride
self.dilation = dilation

self.conv = nn.Conv2d(
self.max_in_channels,
self.max_in_channels,
max(self.kernel_size_list),
self.stride,
groups=self.max_in_channels,
bias=False,
)

self._ks_set = list(set(self.kernel_size_list))
self._ks_set.sort() # e.g., [3, 5, 7]
if self.KERNEL_TRANSFORM_MODE is not None:
# register scaling parameters
# 7to5_matrix, 5to3_matrix
scale_params = {}
for i in range(len(self._ks_set) - 1):
ks_small = self._ks_set[i]
ks_larger = self._ks_set[i + 1]
param_name = "%dto%d" % (ks_larger, ks_small)
# noinspection PyArgumentList
scale_params["%s_matrix" % param_name] = Parameter(
torch.eye(ks_small ** 2)
)
for name, param in scale_params.items():
self.register_parameter(name, param)

self.active_kernel_size = max(self.kernel_size_list)

def get_active_filter(self, in_channel, kernel_size):
out_channel = in_channel
max_kernel_size = max(self.kernel_size_list)

start, end = sub_filter_start_end(max_kernel_size, kernel_size)
filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
start_filter = self.conv.weight[
:out_channel, :in_channel, :, :
] # start with max kernel
for i in range(len(self._ks_set) - 1, 0, -1):
src_ks = self._ks_set[i]
if src_ks <= kernel_size:
break
target_ks = self._ks_set[i - 1]
start, end = sub_filter_start_end(src_ks, target_ks)
_input_filter = start_filter[:, :, start:end, start:end]
_input_filter = _input_filter.contiguous()
_input_filter = _input_filter.view(
_input_filter.size(0), _input_filter.size(1), -1
)
_input_filter = _input_filter.view(-1, _input_filter.size(2))
_input_filter = F.linear(
_input_filter,
self.__getattr__("%dto%d_matrix" % (src_ks, target_ks)),
)
_input_filter = _input_filter.view(
filters.size(0), filters.size(1), target_ks ** 2
)
_input_filter = _input_filter.view(
filters.size(0), filters.size(1), target_ks, target_ks
)
start_filter = _input_filter
filters = start_filter
return filters

def forward(self, x, kernel_size=None):
if kernel_size is None:
kernel_size = self.active_kernel_size
in_channel = x.size(1)

filters = self.get_active_filter(in_channel, kernel_size).contiguous()

padding = get_same_padding(kernel_size)
filters = (
self.conv.weight_standardization(filters)
if isinstance(self.conv, WeightStandardConv2d)
else filters
)
y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, in_channel)
return y


class DynamicConv2d(nn.Module):
def __init__(
self, max_in_channels, max_out_channels, kernel_size=1, stride=1, dilation=1
):
super(DynamicConv2d, self).__init__()

self.max_in_channels = max_in_channels
self.max_out_channels = max_out_channels
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation

self.conv = nn.Conv2d(
self.max_in_channels,
self.max_out_channels,
self.kernel_size,
stride=self.stride,
bias=False,
)

self.active_out_channel = self.max_out_channels

def get_active_filter(self, out_channel, in_channel):
return self.conv.weight[:out_channel, :in_channel, :, :]

def forward(self, x, out_channel=None):
if out_channel is None:
out_channel = self.active_out_channel
in_channel = x.size(1)
filters = self.get_active_filter(out_channel, in_channel).contiguous()

padding = get_same_padding(self.kernel_size)
filters = (
self.conv.weight_standardization(filters)
if isinstance(self.conv, WeightStandardConv2d)
else filters
)
y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, 1)
return y


class DynamicGroupConv2d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size_list,
groups_list,
stride=1,
dilation=1,
):
super(DynamicGroupConv2d, self).__init__()

self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size_list = kernel_size_list
self.groups_list = groups_list
self.stride = stride
self.dilation = dilation

self.conv = nn.Conv2d(
self.in_channels,
self.out_channels,
max(self.kernel_size_list),
self.stride,
groups=min(self.groups_list),
bias=False,
)

self.active_kernel_size = max(self.kernel_size_list)
self.active_groups = min(self.groups_list)

def get_active_filter(self, kernel_size, groups):
start, end = sub_filter_start_end(max(self.kernel_size_list), kernel_size)
filters = self.conv.weight[:, :, start:end, start:end]

sub_filters = torch.chunk(filters, groups, dim=0)
sub_in_channels = self.in_channels // groups
sub_ratio = filters.size(1) // sub_in_channels

filter_crops = []
for i, sub_filter in enumerate(sub_filters):
part_id = i % sub_ratio
start = part_id * sub_in_channels
filter_crops.append(sub_filter[:, start : start + sub_in_channels, :, :])
filters = torch.cat(filter_crops, dim=0)
return filters

def forward(self, x, kernel_size=None, groups=None):
if kernel_size is None:
kernel_size = self.active_kernel_size
if groups is None:
groups = self.active_groups

filters = self.get_active_filter(kernel_size, groups).contiguous()
padding = get_same_padding(kernel_size)
filters = (
self.conv.weight_standardization(filters)
if isinstance(self.conv, WeightStandardConv2d)
else filters
)
y = F.conv2d(
x,
filters,
None,
self.stride,
padding,
self.dilation,
groups,
)
return y


class DynamicBatchNorm2d(nn.Module):
SET_RUNNING_STATISTICS = False

def __init__(self, max_feature_dim):
super(DynamicBatchNorm2d, self).__init__()

self.max_feature_dim = max_feature_dim
self.bn = nn.BatchNorm2d(self.max_feature_dim)

@staticmethod
def bn_forward(x, bn: nn.BatchNorm2d, feature_dim):
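# When the active width equals the full width (or running statistics are being re-calibrated),
# use the BN module directly; otherwise slice the running stats and affine parameters and
# normalize via F.batch_norm.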
if bn.num_features == feature_dim or DynamicBatchNorm2d.SET_RUNNING_STATISTICS:
return bn(x)
else:
exponential_average_factor = 0.0

if bn.training and bn.track_running_stats:
if bn.num_batches_tracked is not None:
bn.num_batches_tracked += 1
if bn.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(bn.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = bn.momentum
return F.batch_norm(
x,
bn.running_mean[:feature_dim],
bn.running_var[:feature_dim],
bn.weight[:feature_dim],
bn.bias[:feature_dim],
bn.training or not bn.track_running_stats,
exponential_average_factor,
bn.eps,
)

def forward(self, x):
feature_dim = x.size(1)
y = self.bn_forward(x, self.bn, feature_dim)
return y


class DynamicGroupNorm(nn.GroupNorm):
def __init__(
self, num_groups, num_channels, eps=1e-5, affine=True, channel_per_group=None
):
super(DynamicGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
self.channel_per_group = channel_per_group

def forward(self, x):
n_channels = x.size(1)
n_groups = n_channels // self.channel_per_group
return F.group_norm(
x, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps
)

@property
def bn(self):
return self


class DynamicSE(SEModule):
def __init__(self, max_channel):
super(DynamicSE, self).__init__(max_channel)

def get_active_reduce_weight(self, num_mid, in_channel, groups=None):
if groups is None or groups == 1:
return self.fc.reduce.weight[:num_mid, :in_channel, :, :]
else:
assert in_channel % groups == 0
sub_in_channels = in_channel // groups
sub_filters = torch.chunk(
self.fc.reduce.weight[:num_mid, :, :, :], groups, dim=1
)
return torch.cat(
[sub_filter[:, :sub_in_channels, :, :] for sub_filter in sub_filters],
dim=1,
)

def get_active_reduce_bias(self, num_mid):
return (
self.fc.reduce.bias[:num_mid] if self.fc.reduce.bias is not None else None
)

def get_active_expand_weight(self, num_mid, in_channel, groups=None):
if groups is None or groups == 1:
return self.fc.expand.weight[:in_channel, :num_mid, :, :]
else:
assert in_channel % groups == 0
sub_in_channels = in_channel // groups
sub_filters = torch.chunk(
self.fc.expand.weight[:, :num_mid, :, :], groups, dim=0
)
return torch.cat(
[sub_filter[:sub_in_channels, :, :, :] for sub_filter in sub_filters],
dim=0,
)

def get_active_expand_bias(self, in_channel, groups=None):
if groups is None or groups == 1:
return (
self.fc.expand.bias[:in_channel]
if self.fc.expand.bias is not None
else None
)
else:
assert in_channel % groups == 0
sub_in_channels = in_channel // groups
sub_bias_list = torch.chunk(self.fc.expand.bias, groups, dim=0)
return torch.cat(
[sub_bias[:sub_in_channels] for sub_bias in sub_bias_list], dim=0
)

def forward(self, x, groups=None):
in_channel = x.size(1)
num_mid = make_divisible(in_channel // self.reduction)

y = x.mean(3, keepdim=True).mean(2, keepdim=True)
# reduce
reduce_filter = self.get_active_reduce_weight(
num_mid, in_channel, groups=groups
).contiguous()
reduce_bias = self.get_active_reduce_bias(num_mid)
y = F.conv2d(y, reduce_filter, reduce_bias, 1, 0, 1, 1)
# relu
y = self.fc.relu(y)
# expand
expand_filter = self.get_active_expand_weight(
num_mid, in_channel, groups=groups
).contiguous()
expand_bias = self.get_active_expand_bias(in_channel, groups=groups)
y = F.conv2d(y, expand_filter, expand_bias, 1, 0, 1, 1)
# hard sigmoid
y = self.fc.h_sigmoid(y)

return x * y


class DynamicLinear(nn.Module):
def __init__(self, max_in_features, max_out_features, bias=True):
super(DynamicLinear, self).__init__()

self.max_in_features = max_in_features
self.max_out_features = max_out_features
self.bias = bias

self.linear = nn.Linear(self.max_in_features, self.max_out_features, self.bias)

self.active_out_features = self.max_out_features

def get_active_weight(self, out_features, in_features):
return self.linear.weight[:out_features, :in_features]

def get_active_bias(self, out_features):
return self.linear.bias[:out_features] if self.bias else None

def forward(self, x, out_features=None):
if out_features is None:
out_features = self.active_out_features

in_features = x.size(1)
weight = self.get_active_weight(out_features, in_features).contiguous()
bias = self.get_active_bias(out_features)
y = F.linear(x, weight, bias)
return y


class DynamicLinearLayer(nn.Module):
def __init__(self, in_features_list, out_features, bias=True, dropout_rate=0):
super(DynamicLinearLayer, self).__init__()

self.in_features_list = in_features_list
self.out_features = out_features
self.bias = bias
self.dropout_rate = dropout_rate

if self.dropout_rate > 0:
self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
else:
self.dropout = None
self.linear = DynamicLinear(
max_in_features=max(self.in_features_list),
max_out_features=self.out_features,
bias=self.bias,
)

def forward(self, x):
if self.dropout is not None:
x = self.dropout(x)
return self.linear(x)

@property
def module_str(self):
return "DyLinear(%d, %d)" % (max(self.in_features_list), self.out_features)

@property
def config(self):
return {
"name": DynamicLinear.__name__,
"in_features_list": self.in_features_list,
"out_features": self.out_features,
"bias": self.bias,
"dropout_rate": self.dropout_rate,
}

@staticmethod
def build_from_config(config):
return DynamicLinearLayer(**config)

def get_active_subnet(self, in_features, preserve_weight=True):
sub_layer = LinearLayer(
in_features, self.out_features, self.bias, dropout_rate=self.dropout_rate
)
sub_layer = sub_layer.to(self.parameters().__next__().device)
if not preserve_weight:
return sub_layer

sub_layer.linear.weight.data.copy_(
self.linear.get_active_weight(self.out_features, in_features).data
)
if self.bias:
sub_layer.linear.bias.data.copy_(
self.linear.get_active_bias(self.out_features).data
)
return sub_layer

def get_active_subnet_config(self, in_features):
return {
"name": LinearLayer.__name__,
"in_features": in_features,
"out_features": self.out_features,
"bias": self.bias,
"dropout_rate": self.dropout_rate,
}


class DynamicMBConvLayer(nn.Module):
def __init__(
self,
in_channel_list,
out_channel_list,
kernel_size_list=3,
expand_ratio_list=6,
stride=1,
act_func="relu6",
use_se=False,
):
super(DynamicMBConvLayer, self).__init__()

self.in_channel_list = in_channel_list
self.out_channel_list = out_channel_list

self.kernel_size_list = val2list(kernel_size_list)
self.expand_ratio_list = val2list(expand_ratio_list)

self.stride = stride
self.act_func = act_func
self.use_se = use_se

# build modules
max_middle_channel = make_divisible(
round(max(self.in_channel_list) * max(self.expand_ratio_list)))
if max(self.expand_ratio_list) == 1:
self.inverted_bottleneck = None
else:
self.inverted_bottleneck = nn.Sequential(
OrderedDict(
[
(
"conv",
DynamicConv2d(
max(self.in_channel_list), max_middle_channel
),
),
("bn", DynamicBatchNorm2d(max_middle_channel)),
("act", build_activation(self.act_func)),
]
)
)

self.depth_conv = nn.Sequential(
OrderedDict(
[
(
"conv",
DynamicSeparableConv2d(
max_middle_channel, self.kernel_size_list, self.stride
),
),
("bn", DynamicBatchNorm2d(max_middle_channel)),
("act", build_activation(self.act_func)),
]
)
)
if self.use_se:
self.depth_conv.add_module("se", DynamicSE(max_middle_channel))

self.point_linear = nn.Sequential(
OrderedDict(
[
(
"conv",
DynamicConv2d(max_middle_channel, max(self.out_channel_list)),
),
("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
]
)
)

self.active_kernel_size = max(self.kernel_size_list)
self.active_expand_ratio = max(self.expand_ratio_list)
self.active_out_channel = max(self.out_channel_list)

def forward(self, x):
in_channel = x.size(1)

if self.inverted_bottleneck is not None:
self.inverted_bottleneck.conv.active_out_channel = make_divisible(
round(in_channel * self.active_expand_ratio))

self.depth_conv.conv.active_kernel_size = self.active_kernel_size
self.point_linear.conv.active_out_channel = self.active_out_channel

if self.inverted_bottleneck is not None:
x = self.inverted_bottleneck(x)
x = self.depth_conv(x)
x = self.point_linear(x)
return x

@property
def module_str(self):
if self.use_se:
return "SE(O%d, E%.1f, K%d)" % (
self.active_out_channel,
self.active_expand_ratio,
self.active_kernel_size,
)
else:
return "(O%d, E%.1f, K%d)" % (
self.active_out_channel,
self.active_expand_ratio,
self.active_kernel_size,
)

@property
def config(self):
return {
"name": DynamicMBConvLayer.__name__,
"in_channel_list": self.in_channel_list,
"out_channel_list": self.out_channel_list,
"kernel_size_list": self.kernel_size_list,
"expand_ratio_list": self.expand_ratio_list,
"stride": self.stride,
"act_func": self.act_func,
"use_se": self.use_se,
}

@staticmethod
def build_from_config(config):
return DynamicMBConvLayer(**config)

############################################################################################

@property
def in_channels(self):
return max(self.in_channel_list)

@property
def out_channels(self):
return max(self.out_channel_list)

def active_middle_channel(self, in_channel):
return make_divisible(round(in_channel * self.active_expand_ratio))

############################################################################################

def get_active_subnet(self, in_channel, preserve_weight=True):
# build the new layer
sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
sub_layer = sub_layer.to(self.parameters().__next__().device)
if not preserve_weight:
return sub_layer

middle_channel = self.active_middle_channel(in_channel)
# copy weight from current layer
if sub_layer.inverted_bottleneck is not None:
sub_layer.inverted_bottleneck.conv.weight.data.copy_(
self.inverted_bottleneck.conv.get_active_filter(
middle_channel, in_channel
).data,
)
copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn)

sub_layer.depth_conv.conv.weight.data.copy_(
self.depth_conv.conv.get_active_filter(
middle_channel, self.active_kernel_size
).data
)
copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn)

if self.use_se:
se_mid = make_divisible(
middle_channel // SEModule.REDUCTION,
)
sub_layer.depth_conv.se.fc.reduce.weight.data.copy_(
self.depth_conv.se.get_active_reduce_weight(se_mid, middle_channel).data
)
sub_layer.depth_conv.se.fc.reduce.bias.data.copy_(
self.depth_conv.se.get_active_reduce_bias(se_mid).data
)

sub_layer.depth_conv.se.fc.expand.weight.data.copy_(
self.depth_conv.se.get_active_expand_weight(se_mid, middle_channel).data
)
sub_layer.depth_conv.se.fc.expand.bias.data.copy_(
self.depth_conv.se.get_active_expand_bias(middle_channel).data
)

sub_layer.point_linear.conv.weight.data.copy_(
self.point_linear.conv.get_active_filter(
self.active_out_channel, middle_channel
).data
)
copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn)

return sub_layer

def get_active_subnet_config(self, in_channel):
return {
"name": MBConvLayer.__name__,
"in_channels": in_channel,
"out_channels": self.active_out_channel,
"kernel_size": self.active_kernel_size,
"stride": self.stride,
"expand_ratio": self.active_expand_ratio,
"mid_channels": self.active_middle_channel(in_channel),
"act_func": self.act_func,
"use_se": self.use_se,
}

def re_organize_middle_weights(self, expand_ratio_stage=0):
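# Rank the expanded (middle) channels by the L1 importance of the point-wise conv weights and
# reorder all dependent weights accordingly, so that smaller expand ratios keep the most
# important channels during progressive shrinking.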
importance = torch.sum(
torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3)
)
if isinstance(self.depth_conv.bn, DynamicGroupNorm):
channel_per_group = self.depth_conv.bn.channel_per_group
importance_chunks = torch.split(importance, channel_per_group)
for chunk in importance_chunks:
chunk.data.fill_(torch.mean(chunk))
importance = torch.cat(importance_chunks, dim=0)
if expand_ratio_stage > 0:
sorted_expand_list = deepcopy(self.expand_ratio_list)
sorted_expand_list.sort(reverse=True)
target_width_list = [
make_divisible(round(max(self.in_channel_list) * expand))
for expand in sorted_expand_list
]

right = len(importance)
base = -len(target_width_list) * 1e5
for i in range(expand_ratio_stage + 1):
left = target_width_list[i]
importance[left:right] += base
base += 1e5
right = left

sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
self.point_linear.conv.conv.weight.data = torch.index_select(
self.point_linear.conv.conv.weight.data, 1, sorted_idx
)

adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx)
self.depth_conv.conv.conv.weight.data = torch.index_select(
self.depth_conv.conv.conv.weight.data, 0, sorted_idx
)

if self.use_se:
# se expand: output dim 0 reorganize
se_expand = self.depth_conv.se.fc.expand
se_expand.weight.data = torch.index_select(
se_expand.weight.data, 0, sorted_idx
)
se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx)
# se reduce: input dim 1 reorganize
se_reduce = self.depth_conv.se.fc.reduce
se_reduce.weight.data = torch.index_select(
se_reduce.weight.data, 1, sorted_idx
)
# middle weight reorganize
se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3))
se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True)

se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx)
se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx)
se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx)

if self.inverted_bottleneck is not None:
adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx)
self.inverted_bottleneck.conv.conv.weight.data = torch.index_select(
self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx
)
return None
else:
return sorted_idx


class DynamicConvLayer(nn.Module):
def __init__(
self,
in_channel_list,
out_channel_list,
kernel_size=3,
stride=1,
dilation=1,
use_bn=True,
act_func="relu6",
):
super(DynamicConvLayer, self).__init__()

self.in_channel_list = in_channel_list
self.out_channel_list = out_channel_list
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.use_bn = use_bn
self.act_func = act_func

self.conv = DynamicConv2d(
max_in_channels=max(self.in_channel_list),
max_out_channels=max(self.out_channel_list),
kernel_size=self.kernel_size,
stride=self.stride,
dilation=self.dilation,
)
if self.use_bn:
self.bn = DynamicBatchNorm2d(max(self.out_channel_list))
self.act = build_activation(self.act_func)

self.active_out_channel = max(self.out_channel_list)

def forward(self, x):
self.conv.active_out_channel = self.active_out_channel

x = self.conv(x)
if self.use_bn:
x = self.bn(x)
x = self.act(x)
return x

@property
def module_str(self):
return "DyConv(O%d, K%d, S%d)" % (
self.active_out_channel,
self.kernel_size,
self.stride,
)

@property
def config(self):
return {
"name": DynamicConvLayer.__name__,
"in_channel_list": self.in_channel_list,
"out_channel_list": self.out_channel_list,
"kernel_size": self.kernel_size,
"stride": self.stride,
"dilation": self.dilation,
"use_bn": self.use_bn,
"act_func": self.act_func,
}

@staticmethod
def build_from_config(config):
return DynamicConvLayer(**config)

############################################################################################

@property
def in_channels(self):
return max(self.in_channel_list)

@property
def out_channels(self):
return max(self.out_channel_list)

############################################################################################

def get_active_subnet(self, in_channel, preserve_weight=True):
sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
sub_layer = sub_layer.to(self.parameters().__next__().device)

if not preserve_weight:
return sub_layer

sub_layer.conv.weight.data.copy_(
self.conv.get_active_filter(self.active_out_channel, in_channel).data
)
if self.use_bn:
copy_bn(sub_layer.bn, self.bn.bn)

return sub_layer

def get_active_subnet_config(self, in_channel):
return {
"name": ConvLayer.__name__,
"in_channels": in_channel,
"out_channels": self.active_out_channel,
"kernel_size": self.kernel_size,
"stride": self.stride,
"dilation": self.dilation,
"use_bn": self.use_bn,
"act_func": self.act_func,
}


class DynamicResNetBottleneckBlock(nn.Module):
def __init__(
self,
in_channel_list,
out_channel_list,
expand_ratio_list=0.25,
kernel_size=3,
stride=1,
act_func="relu",
downsample_mode="avgpool_conv",
):
super(DynamicResNetBottleneckBlock, self).__init__()

self.in_channel_list = in_channel_list
self.out_channel_list = out_channel_list
self.expand_ratio_list = val2list(expand_ratio_list)

self.kernel_size = kernel_size
self.stride = stride
self.act_func = act_func
self.downsample_mode = downsample_mode
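# downsample_mode: 'conv' applies a strided 1x1 convolution on the shortcut;
# 'avgpool_conv' average-pools first and then applies a 1x1 convolution (ResNet-D style).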

# build modules
max_middle_channel = make_divisible(
round(max(self.out_channel_list) * max(self.expand_ratio_list)))

self.conv1 = nn.Sequential(
OrderedDict(
[
(
"conv",
DynamicConv2d(max(self.in_channel_list), max_middle_channel),
),
("bn", DynamicBatchNorm2d(max_middle_channel)),
("act", build_activation(self.act_func, inplace=True)),
]
)
)

self.conv2 = nn.Sequential(
OrderedDict(
[
(
"conv",
DynamicConv2d(
max_middle_channel, max_middle_channel, kernel_size, stride
),
),
("bn", DynamicBatchNorm2d(max_middle_channel)),
("act", build_activation(self.act_func, inplace=True)),
]
)
)

self.conv3 = nn.Sequential(
OrderedDict(
[
(
"conv",
DynamicConv2d(max_middle_channel, max(self.out_channel_list)),
),
("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
]
)
)

if self.stride == 1 and self.in_channel_list == self.out_channel_list:
self.downsample = IdentityLayer(
max(self.in_channel_list), max(self.out_channel_list)
)
elif self.downsample_mode == "conv":
self.downsample = nn.Sequential(
OrderedDict(
[
(
"conv",
DynamicConv2d(
max(self.in_channel_list),
max(self.out_channel_list),
stride=stride,
),
),
("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
]
)
)
elif self.downsample_mode == "avgpool_conv":
self.downsample = nn.Sequential(
OrderedDict(
[
(
"avg_pool",
nn.AvgPool2d(
kernel_size=stride,
stride=stride,
padding=0,
ceil_mode=True,
),
),
(
"conv",
DynamicConv2d(
max(self.in_channel_list), max(self.out_channel_list)
),
),
("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
]
)
)
else:
raise NotImplementedError

self.final_act = build_activation(self.act_func, inplace=True)

self.active_expand_ratio = max(self.expand_ratio_list)
self.active_out_channel = max(self.out_channel_list)

def forward(self, x):
feature_dim = self.active_middle_channels

self.conv1.conv.active_out_channel = feature_dim
self.conv2.conv.active_out_channel = feature_dim
self.conv3.conv.active_out_channel = self.active_out_channel
if not isinstance(self.downsample, IdentityLayer):
self.downsample.conv.active_out_channel = self.active_out_channel

residual = self.downsample(x)

x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)

x = x + residual
x = self.final_act(x)
return x

@property
def module_str(self):
return "(%s, %s)" % (
"%dx%d_BottleneckConv_in->%d->%d_S%d"
% (
self.kernel_size,
self.kernel_size,
self.active_middle_channels,
self.active_out_channel,
self.stride,
),
"Identity"
if isinstance(self.downsample, IdentityLayer)
else self.downsample_mode,
)

@property
def config(self):
return {
"name": DynamicResNetBottleneckBlock.__name__,
"in_channel_list": self.in_channel_list,
"out_channel_list": self.out_channel_list,
"expand_ratio_list": self.expand_ratio_list,
"kernel_size": self.kernel_size,
"stride": self.stride,
"act_func": self.act_func,
"downsample_mode": self.downsample_mode,
}

@staticmethod
def build_from_config(config):
return DynamicResNetBottleneckBlock(**config)

############################################################################################

@property
def in_channels(self):
return max(self.in_channel_list)

@property
def out_channels(self):
return max(self.out_channel_list)

@property
def active_middle_channels(self):
feature_dim = round(self.active_out_channel * self.active_expand_ratio)
feature_dim = make_divisible(feature_dim)
return feature_dim

############################################################################################

def get_active_subnet(self, in_channel, preserve_weight=True):
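        # Materialize a static ResNetBottleneckBlock for the active config and copy the matching weight slices.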
# build the new layer
sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
sub_layer = sub_layer.to(self.parameters().__next__().device)
if not preserve_weight:
return sub_layer

# copy weight from current layer
sub_layer.conv1.conv.weight.data.copy_(
self.conv1.conv.get_active_filter(
self.active_middle_channels, in_channel
).data
)
copy_bn(sub_layer.conv1.bn, self.conv1.bn.bn)

sub_layer.conv2.conv.weight.data.copy_(
self.conv2.conv.get_active_filter(
self.active_middle_channels, self.active_middle_channels
).data
)
copy_bn(sub_layer.conv2.bn, self.conv2.bn.bn)

sub_layer.conv3.conv.weight.data.copy_(
self.conv3.conv.get_active_filter(
self.active_out_channel, self.active_middle_channels
).data
)
copy_bn(sub_layer.conv3.bn, self.conv3.bn.bn)

if not isinstance(self.downsample, IdentityLayer):
sub_layer.downsample.conv.weight.data.copy_(
self.downsample.conv.get_active_filter(
self.active_out_channel, in_channel
).data
)
copy_bn(sub_layer.downsample.bn, self.downsample.bn.bn)

return sub_layer

def get_active_subnet_config(self, in_channel):
return {
"name": ResNetBottleneckBlock.__name__,
"in_channels": in_channel,
"out_channels": self.active_out_channel,
"kernel_size": self.kernel_size,
"stride": self.stride,
"expand_ratio": self.active_expand_ratio,
"mid_channels": self.active_middle_channels,
"act_func": self.act_func,
"groups": 1,
"downsample_mode": self.downsample_mode,
}

def re_organize_middle_weights(self, expand_ratio_stage=0):
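        # Sort middle channels by the L1 importance of the following 1x1 conv so that
        # smaller expand ratios keep the most important channels (used by progressive shrinking).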
# conv3 -> conv2
importance = torch.sum(
torch.abs(self.conv3.conv.conv.weight.data), dim=(0, 2, 3)
)
if isinstance(self.conv2.bn, DynamicGroupNorm):
channel_per_group = self.conv2.bn.channel_per_group
importance_chunks = torch.split(importance, channel_per_group)
for chunk in importance_chunks:
chunk.data.fill_(torch.mean(chunk))
importance = torch.cat(importance_chunks, dim=0)
if expand_ratio_stage > 0:
sorted_expand_list = deepcopy(self.expand_ratio_list)
sorted_expand_list.sort(reverse=True)
target_width_list = [
make_divisible(round(max(self.out_channel_list) * expand))
for expand in sorted_expand_list
]
right = len(importance)
base = -len(target_width_list) * 1e5
for i in range(expand_ratio_stage + 1):
left = target_width_list[i]
importance[left:right] += base
base += 1e5
right = left

sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
self.conv3.conv.conv.weight.data = torch.index_select(
self.conv3.conv.conv.weight.data, 1, sorted_idx
)
adjust_bn_according_to_idx(self.conv2.bn.bn, sorted_idx)
self.conv2.conv.conv.weight.data = torch.index_select(
self.conv2.conv.conv.weight.data, 0, sorted_idx
)

# conv2 -> conv1
importance = torch.sum(
torch.abs(self.conv2.conv.conv.weight.data), dim=(0, 2, 3)
)
if isinstance(self.conv1.bn, DynamicGroupNorm):
channel_per_group = self.conv1.bn.channel_per_group
importance_chunks = torch.split(importance, channel_per_group)
for chunk in importance_chunks:
chunk.data.fill_(torch.mean(chunk))
importance = torch.cat(importance_chunks, dim=0)
if expand_ratio_stage > 0:
sorted_expand_list = deepcopy(self.expand_ratio_list)
sorted_expand_list.sort(reverse=True)
target_width_list = [
make_divisible(round(max(self.out_channel_list) * expand))
for expand in sorted_expand_list
]
right = len(importance)
base = -len(target_width_list) * 1e5
for i in range(expand_ratio_stage + 1):
left = target_width_list[i]
importance[left:right] += base
base += 1e5
right = left
sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)

self.conv2.conv.conv.weight.data = torch.index_select(
self.conv2.conv.conv.weight.data, 1, sorted_idx
)
adjust_bn_according_to_idx(self.conv1.bn.bn, sorted_idx)
self.conv1.conv.conv.weight.data = torch.index_select(
self.conv1.conv.conv.weight.data, 0, sorted_idx
)

return None


def sub_filter_start_end(kernel_size, sub_kernel_size):
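    """Start/end indices of a centered sub-kernel, e.g. sub_filter_start_end(7, 5) == (1, 6)."""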
center = kernel_size // 2
dev = sub_kernel_size // 2
start, end = center - dev, center + dev + 1
assert end - start == sub_kernel_size
return start, end


def adjust_bn_according_to_idx(bn, idx):
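    """Reorder BN affine parameters and running statistics to match a channel permutation."""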
bn.weight.data = torch.index_select(bn.weight.data, 0, idx)
bn.bias.data = torch.index_select(bn.bias.data, 0, idx)
if type(bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
bn.running_mean.data = torch.index_select(bn.running_mean.data, 0, idx)
bn.running_var.data = torch.index_select(bn.running_var.data, 0, idx)


def copy_bn(target_bn, src_bn):
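    """Copy the first feature_dim channels of parameters and statistics from src_bn into target_bn."""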
feature_dim = (
target_bn.num_channels
if isinstance(target_bn, nn.GroupNorm)
else target_bn.num_features
)

target_bn.weight.data.copy_(src_bn.weight.data[:feature_dim])
target_bn.bias.data.copy_(src_bn.bias.data[:feature_dim])
if type(src_bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
target_bn.running_mean.data.copy_(src_bn.running_mean.data[:feature_dim])
target_bn.running_var.data.copy_(src_bn.running_var.data[:feature_dim])

+ 957
- 0
xnas/spaces/OFA/ops.py View File

@@ -0,0 +1,957 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

from xnas.spaces.OFA.utils import (
min_divisible_value,
get_same_padding,
make_divisible,
)


def set_layer_from_config(layer_config):
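    """Rebuild a static layer from its config dict; the 'name' field selects the layer class."""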
if layer_config is None:
return None

name2layer = {
ConvLayer.__name__: ConvLayer,
IdentityLayer.__name__: IdentityLayer,
LinearLayer.__name__: LinearLayer,
MultiHeadLinearLayer.__name__: MultiHeadLinearLayer,
ZeroLayer.__name__: ZeroLayer,
MBConvLayer.__name__: MBConvLayer,
"MBInvertedConvLayer": MBConvLayer,
ResidualBlock.__name__: ResidualBlock,
ResNetBottleneckBlock.__name__: ResNetBottleneckBlock,
}

layer_name = layer_config.pop("name")
layer = name2layer[layer_name]
return layer.build_from_config(layer_config)


"""Activation."""

def build_activation(act_func, inplace=True):
if act_func == "relu":
return nn.ReLU(inplace=inplace)
elif act_func == "relu6":
return nn.ReLU6(inplace=inplace)
elif act_func == "tanh":
return nn.Tanh()
elif act_func == "sigmoid":
return nn.Sigmoid()
elif act_func == "h_swish":
return Hswish(inplace=inplace)
elif act_func == "h_sigmoid":
return Hsigmoid(inplace=inplace)
elif act_func is None or act_func == "none":
return None
else:
raise ValueError("do not support: %s" % act_func)


class Hswish(nn.Module):
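    """Hard-swish activation: x * relu6(x + 3) / 6."""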
def __init__(self, inplace=True):
super(Hswish, self).__init__()
self.inplace = inplace

def forward(self, x):
return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0

def __repr__(self):
return "Hswish()"


class Hsigmoid(nn.Module):
def __init__(self, inplace=True):
super(Hsigmoid, self).__init__()
self.inplace = inplace

def forward(self, x):
return F.relu6(x + 3.0, inplace=self.inplace) / 6.0

def __repr__(self):
return "Hsigmoid()"



class ShuffleLayer(nn.Module):
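    """Channel shuffle (ShuffleNet): interleave channels across groups via reshape-transpose-flatten."""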
def __init__(self, groups):
super(ShuffleLayer, self).__init__()
self.groups = groups

def forward(self, x):
batch_size, num_channels, height, width = x.size()
channels_per_group = num_channels // self.groups
# reshape
x = x.view(batch_size, self.groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batch_size, -1, height, width)
return x

def __repr__(self):
return "ShuffleLayer(groups=%d)" % self.groups


class GlobalAvgPool2d(nn.Module):
def __init__(self, keep_dim=True):
super(GlobalAvgPool2d, self).__init__()
self.keep_dim = keep_dim

def forward(self, x):
return x.mean(3, keepdim=self.keep_dim).mean(2, keepdim=self.keep_dim)

def __repr__(self):
return "GlobalAvgPool2d(keep_dim=%s)" % self.keep_dim


class SEModule(nn.Module):
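    """Squeeze-and-Excitation: global average pooling + reduce/expand 1x1 convs with a hard-sigmoid gate."""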
REDUCTION = 4

def __init__(self, channel, reduction=None):
super(SEModule, self).__init__()

self.channel = channel
self.reduction = SEModule.REDUCTION if reduction is None else reduction

num_mid = make_divisible(self.channel // self.reduction)

self.fc = nn.Sequential(
OrderedDict(
[
("reduce", nn.Conv2d(self.channel, num_mid, 1, 1, 0, bias=True)),
("relu", nn.ReLU(inplace=True)),
("expand", nn.Conv2d(num_mid, self.channel, 1, 1, 0, bias=True)),
("h_sigmoid", Hsigmoid(inplace=True)),
]
)
)

def forward(self, x):
y = x.mean(3, keepdim=True).mean(2, keepdim=True)
y = self.fc(y)
return x * y

def __repr__(self):
return "SE(channel=%d, reduction=%d)" % (self.channel, self.reduction)


class WeightStandardConv2d(nn.Conv2d):
"""
Conv2d with Weight Standardization
https://github.com/joe-siyuan-qiao/WeightStandardization
"""

def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
):
super(WeightStandardConv2d, self).__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
)
self.WS_EPS = None

def weight_standardization(self, weight):
if self.WS_EPS is not None:
weight_mean = (
weight.mean(dim=1, keepdim=True)
.mean(dim=2, keepdim=True)
.mean(dim=3, keepdim=True)
)
weight = weight - weight_mean
std = (
weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1)
+ self.WS_EPS
)
weight = weight / std.expand_as(weight)
return weight

def forward(self, x):
if self.WS_EPS is None:
return super(WeightStandardConv2d, self).forward(x)
else:
return F.conv2d(
x,
self.weight_standardization(self.weight),
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)

def __repr__(self):
return super(WeightStandardConv2d, self).__repr__()[:-1] + ", ws_eps=%s)" % self.WS_EPS


# Base layer that assembles dropout, a weight op, batch norm, and activation in the order given by `ops_order`.
class MixedDCBRLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
use_bn=True,
act_func="relu",
dropout_rate=0,
ops_order="weight_bn_act",
):
super(MixedDCBRLayer, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels

self.use_bn = use_bn
self.act_func = act_func
self.dropout_rate = dropout_rate
self.ops_order = ops_order

modules = {}
# batch norm
if self.use_bn:
if self.bn_before_weight:
modules["bn"] = nn.BatchNorm2d(in_channels)
else:
modules["bn"] = nn.BatchNorm2d(out_channels)
else:
modules["bn"] = None
# activation
modules["act"] = build_activation(
self.act_func, self.ops_list[0] != "act" and self.use_bn
)
# dropout
if self.dropout_rate > 0:
modules["dropout"] = nn.Dropout2d(self.dropout_rate, inplace=True)
else:
modules["dropout"] = None
# weight
modules["weight"] = self.weight_op()

# add modules
for op in self.ops_list:
if modules[op] is None:
continue
elif op == "weight":
# dropout before weight operation
if modules["dropout"] is not None:
self.add_module("dropout", modules["dropout"])
for key in modules["weight"]:
self.add_module(key, modules["weight"][key])
else:
self.add_module(op, modules[op])

@property
def ops_list(self):
return self.ops_order.split("_")

@property
def bn_before_weight(self):
for op in self.ops_list:
if op == "bn":
return True
elif op == "weight":
return False
raise ValueError("Invalid ops_order: %s" % self.ops_order)

def forward(self, x):
# similar to nn.Sequential
for module in self._modules.values():
x = module(x)
return x

@property
def config(self):
return {
"in_channels": self.in_channels,
"out_channels": self.out_channels,
"use_bn": self.use_bn,
"act_func": self.act_func,
"dropout_rate": self.dropout_rate,
"ops_order": self.ops_order,
}


class ConvLayer(MixedDCBRLayer):
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=False,
has_shuffle=False,
use_se=False,
use_bn=True,
act_func="relu",
dropout_rate=0,
ops_order="weight_bn_act",
):
# default normal 3x3_Conv with bn and relu
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.groups = groups
self.bias = bias
self.has_shuffle = has_shuffle
self.use_se = use_se

super(ConvLayer, self).__init__(
in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order
)
if self.use_se:
self.add_module("se", SEModule(self.out_channels))

def weight_op(self):
padding = get_same_padding(self.kernel_size)
if isinstance(padding, int):
padding *= self.dilation
else:
padding[0] *= self.dilation
padding[1] *= self.dilation

weight_dict = OrderedDict(
{
"conv": nn.Conv2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=padding,
dilation=self.dilation,
groups=min_divisible_value(self.in_channels, self.groups),
bias=self.bias,
)
}
)
if self.has_shuffle and self.groups > 1:
weight_dict["shuffle"] = ShuffleLayer(self.groups)

return weight_dict

@property
def module_str(self):
if isinstance(self.kernel_size, int):
kernel_size = (self.kernel_size, self.kernel_size)
else:
kernel_size = self.kernel_size
if self.groups == 1:
if self.dilation > 1:
conv_str = "%dx%d_DilatedConv" % (kernel_size[0], kernel_size[1])
else:
conv_str = "%dx%d_Conv" % (kernel_size[0], kernel_size[1])
else:
if self.dilation > 1:
conv_str = "%dx%d_DilatedGroupConv" % (kernel_size[0], kernel_size[1])
else:
conv_str = "%dx%d_GroupConv" % (kernel_size[0], kernel_size[1])
conv_str += "_O%d" % self.out_channels
if self.use_se:
conv_str = "SE_" + conv_str
conv_str += "_" + self.act_func.upper()
if self.use_bn:
if isinstance(self.bn, nn.GroupNorm):
conv_str += "_GN%d" % self.bn.num_groups
elif isinstance(self.bn, nn.BatchNorm2d):
conv_str += "_BN"
return conv_str

@property
def config(self):
return {
"name": ConvLayer.__name__,
"kernel_size": self.kernel_size,
"stride": self.stride,
"dilation": self.dilation,
"groups": self.groups,
"bias": self.bias,
"has_shuffle": self.has_shuffle,
"use_se": self.use_se,
**super(ConvLayer, self).config,
}


class IdentityLayer(MixedDCBRLayer):
def __init__(
self,
in_channels,
out_channels,
use_bn=False,
act_func=None,
dropout_rate=0,
ops_order="weight_bn_act",
):
super(IdentityLayer, self).__init__(
in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order
)

def weight_op(self):
return None

@property
def module_str(self):
return "Identity"

@property
def config(self):
return {
"name": IdentityLayer.__name__,
**super(IdentityLayer, self).config,
}

@staticmethod
def build_from_config(config):
return IdentityLayer(**config)


class LinearLayer(nn.Module):
def __init__(
self,
in_features,
out_features,
bias=True,
use_bn=False,
act_func=None,
dropout_rate=0,
ops_order="weight_bn_act",
):
super(LinearLayer, self).__init__()

self.in_features = in_features
self.out_features = out_features
self.bias = bias

self.use_bn = use_bn
self.act_func = act_func
self.dropout_rate = dropout_rate
self.ops_order = ops_order

""" modules """
modules = {}
# batch norm
if self.use_bn:
if self.bn_before_weight:
modules["bn"] = nn.BatchNorm1d(in_features)
else:
modules["bn"] = nn.BatchNorm1d(out_features)
else:
modules["bn"] = None
# activation
modules["act"] = build_activation(self.act_func, self.ops_list[0] != "act")
# dropout
if self.dropout_rate > 0:
modules["dropout"] = nn.Dropout(self.dropout_rate, inplace=True)
else:
modules["dropout"] = None
# linear
modules["weight"] = {
"linear": nn.Linear(self.in_features, self.out_features, self.bias)
}

# add modules
for op in self.ops_list:
if modules[op] is None:
continue
elif op == "weight":
if modules["dropout"] is not None:
self.add_module("dropout", modules["dropout"])
for key in modules["weight"]:
self.add_module(key, modules["weight"][key])
else:
self.add_module(op, modules[op])

@property
def ops_list(self):
return self.ops_order.split("_")

@property
def bn_before_weight(self):
for op in self.ops_list:
if op == "bn":
return True
elif op == "weight":
return False
raise ValueError("Invalid ops_order: %s" % self.ops_order)

def forward(self, x):
for module in self._modules.values():
x = module(x)
return x

@property
def module_str(self):
return "%dx%d_Linear" % (self.in_features, self.out_features)

@property
def config(self):
return {
"name": LinearLayer.__name__,
"in_features": self.in_features,
"out_features": self.out_features,
"bias": self.bias,
"use_bn": self.use_bn,
"act_func": self.act_func,
"dropout_rate": self.dropout_rate,
"ops_order": self.ops_order,
}

@staticmethod
def build_from_config(config):
return LinearLayer(**config)


class MultiHeadLinearLayer(nn.Module):
def __init__(
self, in_features, out_features, num_heads=1, bias=True, dropout_rate=0
):
super(MultiHeadLinearLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.num_heads = num_heads

self.bias = bias
self.dropout_rate = dropout_rate

if self.dropout_rate > 0:
self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
else:
self.dropout = None

self.layers = nn.ModuleList()
for k in range(num_heads):
layer = nn.Linear(in_features, out_features, self.bias)
self.layers.append(layer)

def forward(self, inputs):
if self.dropout is not None:
inputs = self.dropout(inputs)

outputs = []
for layer in self.layers:
output = layer.forward(inputs)
outputs.append(output)

outputs = torch.stack(outputs, dim=1)
return outputs

@property
def module_str(self):
return self.__repr__()

@property
def config(self):
return {
"name": MultiHeadLinearLayer.__name__,
"in_features": self.in_features,
"out_features": self.out_features,
"num_heads": self.num_heads,
"bias": self.bias,
"dropout_rate": self.dropout_rate,
}

@staticmethod
def build_from_config(config):
return MultiHeadLinearLayer(**config)

def __repr__(self):
return (
"MultiHeadLinear(in_features=%d, out_features=%d, num_heads=%d, bias=%s, dropout_rate=%s)" % (
self.in_features,
self.out_features,
self.num_heads,
self.bias,
self.dropout_rate,
))


class ZeroLayer(nn.Module):
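    """Placeholder for a skipped op; ResidualBlock treats it as identity, so forward is never called."""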
def __init__(self):
super(ZeroLayer, self).__init__()

def forward(self, x):
raise ValueError

@property
def module_str(self):
return "Zero"

@property
def config(self):
return {
"name": ZeroLayer.__name__,
}

@staticmethod
def build_from_config(config):
return ZeroLayer()


class MBConvLayer(nn.Module):
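    """MobileNetV2-style inverted bottleneck: 1x1 expansion -> depthwise conv (optional SE) -> 1x1 projection."""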
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
expand_ratio=6,
mid_channels=None,
act_func="relu6",
use_se=False,
groups=None,
):
super(MBConvLayer, self).__init__()

self.in_channels = in_channels
self.out_channels = out_channels

self.kernel_size = kernel_size
self.stride = stride
self.expand_ratio = expand_ratio
self.mid_channels = mid_channels
self.act_func = act_func
self.use_se = use_se
self.groups = groups

if self.mid_channels is None:
feature_dim = round(self.in_channels * self.expand_ratio)
else:
feature_dim = self.mid_channels

if self.expand_ratio == 1:
self.inverted_bottleneck = None
else:
self.inverted_bottleneck = nn.Sequential(
OrderedDict(
[
(
"conv",
nn.Conv2d(
self.in_channels, feature_dim, 1, 1, 0, bias=False
),
),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]
)
)

pad = get_same_padding(self.kernel_size)
groups = (
feature_dim
if self.groups is None
else min_divisible_value(feature_dim, self.groups)
)
depth_conv_modules = [
(
"conv",
nn.Conv2d(
feature_dim,
feature_dim,
kernel_size,
stride,
pad,
groups=groups,
bias=False,
),
),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]
if self.use_se:
depth_conv_modules.append(("se", SEModule(feature_dim)))
self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules))

self.point_linear = nn.Sequential(
OrderedDict(
[
("conv", nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
("bn", nn.BatchNorm2d(out_channels)),
]
)
)

def forward(self, x):
if self.inverted_bottleneck:
x = self.inverted_bottleneck(x)
x = self.depth_conv(x)
x = self.point_linear(x)
return x

@property
def module_str(self):
if self.mid_channels is None:
expand_ratio = self.expand_ratio
else:
expand_ratio = self.mid_channels // self.in_channels
layer_str = "%dx%d_MBConv%d_%s" % (
self.kernel_size,
self.kernel_size,
expand_ratio,
self.act_func.upper(),
)
if self.use_se:
layer_str = "SE_" + layer_str
layer_str += "_O%d" % self.out_channels
if self.groups is not None:
layer_str += "_G%d" % self.groups
if isinstance(self.point_linear.bn, nn.GroupNorm):
layer_str += "_GN%d" % self.point_linear.bn.num_groups
elif isinstance(self.point_linear.bn, nn.BatchNorm2d):
layer_str += "_BN"

return layer_str

@property
def config(self):
return {
"name": MBConvLayer.__name__,
"in_channels": self.in_channels,
"out_channels": self.out_channels,
"kernel_size": self.kernel_size,
"stride": self.stride,
"expand_ratio": self.expand_ratio,
"mid_channels": self.mid_channels,
"act_func": self.act_func,
"use_se": self.use_se,
"groups": self.groups,
}

@staticmethod
def build_from_config(config):
return MBConvLayer(**config)


class ResidualBlock(nn.Module):
def __init__(self, conv, shortcut):
super(ResidualBlock, self).__init__()

self.conv = conv
self.shortcut = shortcut

def forward(self, x):
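        # A missing/Zero conv means identity; a missing/Zero shortcut means no skip connection.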
if self.conv is None or isinstance(self.conv, ZeroLayer):
res = x
elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
res = self.conv(x)
else:
res = self.conv(x) + self.shortcut(x)
return res

@property
def module_str(self):
return "(%s, %s)" % (
self.conv.module_str if self.conv is not None else None,
self.shortcut.module_str if self.shortcut is not None else None,
)

@property
def config(self):
return {
"name": ResidualBlock.__name__,
"conv": self.conv.config if self.conv is not None else None,
"shortcut": self.shortcut.config if self.shortcut is not None else None,
}

@staticmethod
def build_from_config(config):
conv_config = (
config["conv"] if "conv" in config else config["mobile_inverted_conv"]
)
conv = set_layer_from_config(conv_config)
shortcut = set_layer_from_config(config["shortcut"])
return ResidualBlock(conv, shortcut)

@property
def mobile_inverted_conv(self):
return self.conv


class ResNetBottleneckBlock(nn.Module):
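    """Static ResNet bottleneck: 1x1 reduce -> kxk conv -> 1x1 expand, with identity/conv/avgpool_conv downsample."""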
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
expand_ratio=0.25,
mid_channels=None,
act_func="relu",
groups=1,
downsample_mode="avgpool_conv",
):
super(ResNetBottleneckBlock, self).__init__()

self.in_channels = in_channels
self.out_channels = out_channels

self.kernel_size = kernel_size
self.stride = stride
self.expand_ratio = expand_ratio
self.mid_channels = mid_channels
self.act_func = act_func
self.groups = groups

self.downsample_mode = downsample_mode

if self.mid_channels is None:
feature_dim = round(self.out_channels * self.expand_ratio)
else:
feature_dim = self.mid_channels

feature_dim = make_divisible(feature_dim)
self.mid_channels = feature_dim

# build modules
self.conv1 = nn.Sequential(
OrderedDict(
[
(
"conv",
nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False),
),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]
)
)

pad = get_same_padding(self.kernel_size)
self.conv2 = nn.Sequential(
OrderedDict(
[
(
"conv",
nn.Conv2d(
feature_dim,
feature_dim,
kernel_size,
stride,
pad,
groups=groups,
bias=False,
),
),
("bn", nn.BatchNorm2d(feature_dim)),
("act", build_activation(self.act_func, inplace=True)),
]
)
)

self.conv3 = nn.Sequential(
OrderedDict(
[
(
"conv",
nn.Conv2d(feature_dim, self.out_channels, 1, 1, 0, bias=False),
),
("bn", nn.BatchNorm2d(self.out_channels)),
]
)
)

if stride == 1 and in_channels == out_channels:
self.downsample = IdentityLayer(in_channels, out_channels)
elif self.downsample_mode == "conv":
self.downsample = nn.Sequential(
OrderedDict(
[
(
"conv",
nn.Conv2d(
in_channels, out_channels, 1, stride, 0, bias=False
),
),
("bn", nn.BatchNorm2d(out_channels)),
]
)
)
elif self.downsample_mode == "avgpool_conv":
self.downsample = nn.Sequential(
OrderedDict(
[
(
"avg_pool",
nn.AvgPool2d(
kernel_size=stride,
stride=stride,
padding=0,
ceil_mode=True,
),
),
(
"conv",
nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
),
("bn", nn.BatchNorm2d(out_channels)),
]
)
)
else:
raise NotImplementedError

self.final_act = build_activation(self.act_func, inplace=True)

def forward(self, x):
residual = self.downsample(x)

x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)

x = x + residual
x = self.final_act(x)
return x

@property
def module_str(self):
return "(%s, %s)" % (
"%dx%d_BottleneckConv_%d->%d->%d_S%d_G%d"
% (
self.kernel_size,
self.kernel_size,
self.in_channels,
self.mid_channels,
self.out_channels,
self.stride,
self.groups,
),
"Identity"
if isinstance(self.downsample, IdentityLayer)
else self.downsample_mode,
)

@property
def config(self):
return {
"name": ResNetBottleneckBlock.__name__,
"in_channels": self.in_channels,
"out_channels": self.out_channels,
"kernel_size": self.kernel_size,
"stride": self.stride,
"expand_ratio": self.expand_ratio,
"mid_channels": self.mid_channels,
"act_func": self.act_func,
"groups": self.groups,
"downsample_mode": self.downsample_mode,
}

@staticmethod
def build_from_config(config):
return ResNetBottleneckBlock(**config)

+ 111
- 0
xnas/spaces/OFA/utils.py View File

@@ -0,0 +1,111 @@
"""Utilities."""

import math
import numpy as np
import torch.nn as nn


def list_sum(x):
return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])


def list_mean(x):
return list_sum(x) / len(x)


def min_divisible_value(n1, v1):
"""make sure v1 is divisible by n1, otherwise decrease v1"""
if v1 >= n1:
return n1
while n1 % v1 != 0:
v1 -= 1
return v1

def val2list(val, repeat_time=1):
if isinstance(val, list) or isinstance(val, np.ndarray):
return val
elif isinstance(val, tuple):
return list(val)
else:
return [val for _ in range(repeat_time)]


""" Layer releated"""

def get_same_padding(kernel_size):
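    # 'same' padding for stride-1 convs, e.g. get_same_padding(3) -> 1, get_same_padding((3, 5)) -> (1, 2)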
if isinstance(kernel_size, tuple):
assert len(kernel_size) == 2, "invalid kernel size: %s" % kernel_size
p1 = get_same_padding(kernel_size[0])
p2 = get_same_padding(kernel_size[1])
return p1, p2
assert isinstance(kernel_size, int), "kernel size should be either `int` or `tuple`"
assert kernel_size % 2 > 0, "kernel size should be odd number"
return kernel_size // 2

def make_divisible(v, divisor=8, min_val=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_val:
:return:
"""
if min_val is None:
min_val = divisor
new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v


""" BN related """

def clean_num_batch_tracked(net):
for m in net.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
if m.num_batches_tracked is not None:
m.num_batches_tracked.zero_()


def rm_bn_from_net(net):
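    """Turn every BatchNorm layer into an identity mapping by overriding its forward."""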
for m in net.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
m.forward = lambda x: x


""" network related """

def init_model(net, model_init="he_fout"):
"""
Conv2d,
BatchNorm2d, BatchNorm1d, GroupNorm
Linear,
"""
if isinstance(net, list):
for sub_net in net:
            init_model(sub_net, model_init)
return
for m in net.modules():
if isinstance(m, nn.Conv2d):
if model_init == "he_fout":
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2.0 / n))
elif model_init == "he_fin":
n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
m.weight.data.normal_(0, math.sqrt(2.0 / n))
else:
raise NotImplementedError
if m.bias is not None:
m.bias.data.zero_()
elif type(m) in [nn.BatchNorm2d, nn.BatchNorm1d, nn.GroupNorm]:
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
stdv = 1.0 / math.sqrt(m.weight.size(1))
m.weight.data.uniform_(-stdv, stdv)
if m.bias is not None:
m.bias.data.zero_()

+ 3
- 0
xnas/spaces/SPOS/cnn.py View File

@@ -176,6 +176,9 @@ class SPOS_supernet(nn.Module):
        # self.global_pooling = nn.AvgPool2d(7)
        self.classifier = nn.Linear(LAST_CHANNEL, self.classes, bias=False)
        self._initialize_weights()
    def weights(self):
        return self.parameters()

    def forward(self, x, choice=np.random.randint(4, size=20)):
        x = self.stem(x)

