From e9f4b3c81d39e55232c5af15358f5db69462fe64 Mon Sep 17 00:00:00 2001 From: chouxianyu <774675548@qq.com> Date: Sun, 22 May 2022 12:35:16 +0800 Subject: [PATCH 1/5] refactor docs for sphinx --- .gitignore | 1 + .readthedocs.yml | 9 +++ docs/Data_preparation.md | 14 ++++- docs/conf.py | 70 +++++++++++++++++++++ docs/{Getting_started.md => get_started.md} | 14 ++--- docs/index.rst | 35 +++++++++++ docs/{Contributing.md => notes.md} | 4 +- docs/requirements.txt | 4 ++ 8 files changed, 140 insertions(+), 11 deletions(-) create mode 100644 .readthedocs.yml create mode 100644 docs/conf.py rename docs/{Getting_started.md => get_started.md} (93%) create mode 100644 docs/index.rst rename docs/{Contributing.md => notes.md} (90%) create mode 100644 docs/requirements.txt diff --git a/.gitignore b/.gitignore index caa76dd..a9a95b4 100644 --- a/.gitignore +++ b/.gitignore @@ -83,6 +83,7 @@ instance/ # Sphinx documentation docs/_build/ +docs/build/ # PyBuilder target/ diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 0000000..b23b134 --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,9 @@ +version: 2 + +sphinx: + configuration: docs/conf.py + +python: + version: 3.7 + install: + - requirements: docs/requirements.txt \ No newline at end of file diff --git a/docs/Data_preparation.md b/docs/Data_preparation.md index cc39354..5aa4091 100644 --- a/docs/Data_preparation.md +++ b/docs/Data_preparation.md @@ -1,4 +1,4 @@ -## Data Preparation +# Common Settings It is highly recommended to save or link datasets to the `$XNAS/data` folder, thus no additional configuration is required. @@ -7,7 +7,7 @@ However, manually setting the path for datasets is also available by modifying t Additionally, files required by benchmarks are also in the `$XNAS/data` folder. You can also modify related attributes under `cfg.BENCHMARK` in the configuration file, to match your actual file locations. -### Supported Datasets +# Dataset Preparation The dataloaders of XNAS will read the dataset files from `$XNAS/data/$DATASET_NAME` by default, and we use lowercase filenames and remove the hyphens. For example, files for CIFAR-10 should be placed (or auto downloaded) under `$XNAS/data/cifar/` directory. @@ -21,3 +21,13 @@ XNAS currently supports the following datasets. - MNIST - FashionMNIST +# Benchmark Preparation + +Some search spaces or algorithms supported by XNAS require specific APIs provided by NAS benchmarks. Installation and properly setting are required to run these code. + +Benchmarks supported by XNAS and their linkes are following. +- nasbench101: [GitHub](https://github.com/google-research/nasbench) + - nasbench1shot1: [GitHub](https://github.com/automl/nasbench-1shot1) +- nasbench201: [GitHub](https://github.com/D-X-Y/NAS-Bench-201) +- nasbench301: [GitHub](https://github.com/automl/nasbench301) + diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..0246723 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,70 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + + +# -- Project information ----------------------------------------------------- + +project = 'XNAS' +copyright = '2022, PCL_AutoML' +author = 'PCL_AutoML' + +# The full version, including alpha/beta/rc tags +release = 'v0.0.1' + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'recommonmark', + 'sphinx_markdown_tables' +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = 'en' +# language = 'zh_CN' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# + +# html_theme = 'alabaster' +# html_theme = 'sphinx_rtd_theme' +html_theme = 'furo' + + + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] \ No newline at end of file diff --git a/docs/Getting_started.md b/docs/get_started.md similarity index 93% rename from docs/Getting_started.md rename to docs/get_started.md index 366137c..ff31a13 100644 --- a/docs/Getting_started.md +++ b/docs/get_started.md @@ -1,10 +1,10 @@ -## Getting Started +# Prerequisites XNAS does not provide installation via `pip` currently. To run XNAS, `python>=3.7` and `pytorch==1.9` are required. Other versions of `PyTorch` may also work well, but there are potential API differences that can cause warnings to be generated. We have listed other requirements in `requirements.txt` file. -## Installation +# Installation 1. Clone this repo. 2. (Optional) Create a virtualenv for this library. @@ -38,16 +38,16 @@ Benchmarks supported by XNAS and their linkes are following. - nasbench201: [GitHub](https://github.com/D-X-Y/NAS-Bench-201) - nasbench301: [GitHub](https://github.com/automl/nasbench301) -For detailed instructions to install these benchmarks, please refer to the `$XNAS/docs/benchmarks` directory. +For detailed instructions to install these benchmarks, please refer to [**Data Preparation**](./data_preparation.md). -## Usage +# Usage -Before running code in XNAS, please make sure you have followed instructions in [**Data_preparation.md**](./Data_preparation.md) in our docs to complete preparing the necessary data. +Before running code in XNAS, please make sure you have followed instructions in [**Data Preparation**](./data_preparation.md) in our docs to complete preparing the necessary data. The main program entries for the search and training process are in the `$XNAS/scripts` folder. To modify and add NAS code, please place files in this folder. -### Configuration Files +## Configuration Files XNAS uses the `.yaml` file format to organize the configuration files. All configuration files are placed under `$XNAS/configs` directory. To ensure the uniformity and clarity of files, we strongly recommend using the following naming convention: @@ -61,7 +61,7 @@ For example, using `DARTS` algorithm, searching on `NASBench201` space and `CIFA darts_nasbench201_cifar10_nasbench301_maxepoch75.yaml ``` -### Running Examples +## Running Examples XNAS reads configuration files from the command line. A simple running example is following: diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..760ec12 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,35 @@ +.. XNAS documentation master file, created by + sphinx-quickstart on Sat May 21 14:09:17 2022. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + + +Welcome to XNAS' documentation! +==================================== + + +.. toctree:: + :maxdepth: 2 + :caption: Get Started + + get_started.md + +.. toctree:: + :maxdepth: 2 + :caption: Data Preparation + + data_preparation.md + + +.. toctree:: + :maxdepth: 2 + :caption: Notes + + notes.md + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`search` \ No newline at end of file diff --git a/docs/Contributing.md b/docs/notes.md similarity index 90% rename from docs/Contributing.md rename to docs/notes.md index d0dbfbf..f084067 100644 --- a/docs/Contributing.md +++ b/docs/notes.md @@ -1,8 +1,8 @@ -## Contributing to XNAS +# Contributing We welcome contributions to the library along with any potential issues or suggestions. -### Pull Requests +## Pull Requests We actively welcome your pull requests. diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..1d2c10c --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,4 @@ +sphinx +furo +recommonmark +sphinx_markdown_tables -- 2.34.1 From 942397b4d78d09a1df7e9a34dabe64a5719031b7 Mon Sep 17 00:00:00 2001 From: chouxianyu Date: Sun, 22 May 2022 12:37:10 +0800 Subject: [PATCH 2/5] rename 'docs/Data_preparation.md' to 'docs/data_preparation.md' --- docs/{Data_preparation.md => data_preparation.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/{Data_preparation.md => data_preparation.md} (100%) diff --git a/docs/Data_preparation.md b/docs/data_preparation.md similarity index 100% rename from docs/Data_preparation.md rename to docs/data_preparation.md -- 2.34.1 From 307a6932d8d857d18321520ded4e37fee9f66039 Mon Sep 17 00:00:00 2001 From: xfey Date: Tue, 24 May 2022 12:12:10 +0800 Subject: [PATCH 3/5] fix SPOS and add SNG_search --- README.md | 1 - scripts/search/SNG/search.py | 169 +++++++++++++++++++++++++++++++++++ scripts/search/SPOS.py | 18 ++-- xnas/core/builder.py | 4 +- xnas/runner/criterion.py | 55 ++++++++++++ xnas/runner/optimizer.py | 16 +--- xnas/runner/trainer.py | 60 +++++++------ xnas/spaces/DARTS/cnn.py | 3 + xnas/spaces/SPOS/cnn.py | 3 + 9 files changed, 282 insertions(+), 47 deletions(-) create mode 100644 xnas/runner/criterion.py diff --git a/README.md b/README.md index 52168d7..7463c5c 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,6 @@ This project is released under the [MIT license](https://mit-license.org). - 迁移OFA代码 - 补充101&201安装测试 - - 补完SNG Search代码 - 检查201搜索空间 - 检查RMINAS - 补充模块测试案例 diff --git a/scripts/search/SNG/search.py b/scripts/search/SNG/search.py index e69de29..4fcd3fc 100755 --- a/scripts/search/SNG/search.py +++ b/scripts/search/SNG/search.py @@ -0,0 +1,169 @@ +"""SNG searching + +(DARTS space only) +""" + +import torch +import random +import numpy as np + +import xnas.core.config as config +import xnas.logger.meter as meter +import xnas.logger.logging as logging +from xnas.core.config import cfg +from xnas.core.builder import * +from xnas.core.utils import index_to_one_hot, one_hot_to_index + +from xnas.runner.trainer import OneShotTrainer + + +# Load config and check +config.load_configs() +logger = logging.get_logger(__name__) + + +def main(): + setup_env() + search_space = space_builder().cuda() + criterion = criterion_builder().cuda() + evaluator = evaluator_builder() + [train_loader, valid_loader] = construct_loader() + + w_optim = optimizer_builder("SGD", search_space.parameters()) + lr_scheduler = lr_scheduler_builder(w_optim) + + if cfg.SPACE.NAME in ['darts']: + distribution_optimizer = SNG_builder([search_space.num_ops]*search_space.all_edges) + else: + raise NotImplementedError + + # Trainer definition + sng_trainer = OneShotTrainer( + supernet=search_space, + criterion=criterion, + optimizer=w_optim, + lr_scheduler=lr_scheduler, + train_loader=train_loader, + test_loader=valid_loader, + ) + + over_all_epoch = 0 + + # === Warmup === + logger.info("=== Warmup Training ===") + sng_trainer.start() + for cur_epoch in range(cfg.OPTIM.WARMUP_EPOCH): + if cfg.SNG.WARMUP_RANDOM_SAMPLE: + sample = random_sampling(search_space, distribution_optimizer, cur_epoch) + else: + num_ops, total_edges = search_space.num_ops, search_space.all_edges + array_sample = [random.sample(list(range(num_ops)), num_ops) for i in range(total_edges)] + array_sample = np.array(array_sample) + for i in range(num_ops): + sample = np.transpose(array_sample[:, i]) + sample = index_to_one_hot(sample, distribution_optimizer.p_model.Cmax) + logger.info("Warmup Sampling: {}".format(one_hot_to_index(sample))) + sng_trainer.train_epoch(over_all_epoch, sample) + sng_trainer.test_epoch(over_all_epoch, sample) + over_all_epoch += 1 + sng_trainer.finish() + + logger.info("=== Training ===") + sng_trainer.start() + for cur_epoch in range(cfg.OPTIM.MAX_EPOCH): + if hasattr(distribution_optimizer, 'training_finish'): + if distribution_optimizer.training_finish: + break + sample = random_sampling(search_space, distribution_optimizer, epoch=cur_epoch, _random=cfg.SNG.RANDOM_SAMPLE) + logger.info("Sampling: {}".format(one_hot_to_index(sample))) + sng_trainer.train_epoch(over_all_epoch, sample) + top1_err = sng_trainer.test_epoch(over_all_epoch, sample) + over_all_epoch += 1 + # TODO: REA & RAND in algorithm/SPOS are similar to this optimizer. Adding them to SNG series? + distribution_optimizer.record_information(sample, top1_err) + distribution_optimizer.update() + # Evaluate the model + if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH: + logger.info("=== Optimal genotype at epoch: {} ===".format(cur_epoch)) + logger.info(search_space.genotype(distribution_optimizer.p_model.theta)) + logger.info("=== alphas at epoch: {} ===".format(cur_epoch)) + for alpha in distribution_optimizer.p_model.theta: + logger.info(alpha) + sng_trainer.finish() + + logger.info("=== Final epochs ===") + sng_trainer.start() + for cur_epoch in range(cfg.OPTIM.FINAL_EPOCH): + if cfg.SPACE.NAME in ['darts']: + genotype = search_space.genotype(distribution_optimizer.p_model.theta) + sample = search_space.genotype_to_onehot_sample(genotype) + else: + sample = distribution_optimizer.sampling_best() + sng_trainer.train_epoch(over_all_epoch, sample) + sng_trainer.test_epoch(over_all_epoch, sample) + over_all_epoch += 1 + sng_trainer.finish() + + if cfg.SPACE.NAME in ['darts']: + best_genotype = search_space.genotype(distribution_optimizer.p_model.theta) + # evaluator(genotype) # TODO: NAS-Bench-301 support. + + +def random_sampling(search_space, distribution_optimizer, epoch=-1000, _random=False): + """random sampling""" + if _random: + num_ops, total_edges = search_space.num_ops, search_space.all_edges + # Edge importance + non_edge_idx = [] + if cfg.SNG.EDGE_SAMPLING and epoch > cfg.SNG.EDGE_SAMPLING_EPOCH: + assert cfg.SPACE.NAME in ['darts'], "only support darts for now!" + norm_indexes = search_space.norm_node_index + non_edge_idx = [] + for node in norm_indexes: + # DARTS: N=7 nodes + edge_non_prob = distribution_optimizer.p_model.theta[np.array(node), 7] + edge_non_prob = edge_non_prob / np.sum(edge_non_prob) + if len(node) == 2: + pass + else: + non_edge_sampling_num = len(node) - 2 + non_edge_idx += list(np.random.choice(node, non_edge_sampling_num, p=edge_non_prob, replace=False)) + # Big model sampling with probability + if random.random() < cfg.SNG.BIGMODEL_SAMPLE_PROB: + # Sample the network with high complexity + _num = 100 + while _num > cfg.SNG.BIGMODEL_NON_PARA: + _error = False + if cfg.SNG.PROB_SAMPLING: + sample = np.array([np.random.choice(num_ops, 1, p=distribution_optimizer.p_model.theta[i, :])[0] for i in range(total_edges)]) + else: + sample = np.array([np.random.choice(num_ops, 1)[0] for i in range(total_edges)]) + _num = 0 + for i in sample[0:search_space.num_edges]: + if i in non_edge_idx: + pass + elif i in search_space.non_op_idx: + if i == 7: + _error = True + _num = _num + 1 + if _error: + _num = 100 + else: + if cfg.SNG.PROB_SAMPLING: + sample = np.array([np.random.choice(num_ops, 1, p=distribution_optimizer.p_model.theta[i, :])[0] + for i in range(total_edges)]) + else: + sample = np.array([np.random.choice(num_ops, 1)[0] for i in range(total_edges)]) + if cfg.SNG.EDGE_SAMPLING and epoch > cfg.SNG.EDGE_SAMPLING_EPOCH: + for i in non_edge_idx: + sample[i] = 7 + sample = index_to_one_hot(sample, distribution_optimizer.p_model.Cmax) + # in the pruning method we have to sampling anyway + distribution_optimizer.sampling() + return sample + else: + return distribution_optimizer.sampling() + + +if __name__ == "__main__": + main() diff --git a/scripts/search/SPOS.py b/scripts/search/SPOS.py index b341d15..94cd839 100755 --- a/scripts/search/SPOS.py +++ b/scripts/search/SPOS.py @@ -23,7 +23,8 @@ def main(): lr_scheduler = lr_scheduler_builder(optimizer) # init sampler - train_sampler, evaluate_sampler = RAND(), REA() + train_sampler = RAND(cfg.SPOS.NUM_CHOICE, cfg.SPOS.LAYERS) + evaluate_sampler = REA(cfg.SPOS.NUM_CHOICE, cfg.SPOS.LAYERS) # init recorders spos_trainer = OneShotTrainer( @@ -33,9 +34,9 @@ def main(): lr_scheduler=lr_scheduler, train_loader=train_loader, test_loader=valid_loader, - train_sampler=train_sampler, - evaluate_sampler=evaluate_sampler, + sample_type='iter' ) + spos_trainer.register_iter_sample(train_sampler) # load checkpoint or initial weights start_epoch = spos_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0 @@ -44,16 +45,19 @@ def main(): spos_trainer.start() for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH): # train epoch - spos_trainer.train_epoch(cur_epoch) + top1_err = spos_trainer.train_epoch(cur_epoch) # test epoch if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH: - spos_trainer.test_epoch(cur_epoch) + top1_err = spos_trainer.test_epoch(cur_epoch) spos_trainer.finish() # sample best architecture from supernet - best_arch, best_top1err = spos_trainer.best_arch() + for cycle in range(200): # NOTE: this should be a hyperparameter + sample = evaluate_sampler.suggest() + top1_err = spos_trainer.evaluate_epoch(sample) + evaluate_sampler.record(sample, top1_err) + best_arch, best_top1err = evaluate_sampler.final_best() logger.info("Best arch: {} \nTop1 error: {}".format(best_arch, best_top1err)) - if __name__ == '__main__': main() diff --git a/xnas/core/builder.py b/xnas/core/builder.py index 091bca2..d7491bf 100644 --- a/xnas/core/builder.py +++ b/xnas/core/builder.py @@ -21,7 +21,8 @@ from xnas.core.config import cfg # Dataloader from xnas.datasets.loader import construct_loader # Optimizers, criterions and LR_schedulers -from xnas.runner.optimizer import optimizer_builder, criterion_builder +from xnas.runner.optimizer import optimizer_builder +from xnas.runner.criterion import criterion_builder from xnas.runner.scheduler import lr_scheduler_builder @@ -31,6 +32,7 @@ __all__ = [ 'criterion_builder', 'lr_scheduler_builder', 'space_builder', + 'SNG_builder', 'evaluator_builder', 'setup_env', ] diff --git a/xnas/runner/criterion.py b/xnas/runner/criterion.py new file mode 100644 index 0000000..8b33e36 --- /dev/null +++ b/xnas/runner/criterion.py @@ -0,0 +1,55 @@ +"""Loss functions.""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from xnas.core.config import cfg + + +__all__ = ['criterion_builder'] + + +def smoothed_cross_entropy_loss(pred, target, label_smoothing=0.): + def _label_smooth(target, n_classes: int, label_smoothing): + # convert to one-hot + batch_size = target.size(0) + target = torch.unsqueeze(target, 1) + soft_target = torch.zeros((batch_size, n_classes), device=target.device) + soft_target.scatter_(1, target, 1) + # label smoothing + soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes + return soft_target + + label_smoothing = cfg.SEARCH.LABEL_SMOOTH if label_smoothing == 0. else label_smoothing + soft_target = _label_smooth(target, pred.size(1), label_smoothing) + logsoftmax = nn.LogSoftmax() + return torch.mean(torch.sum(-soft_target * logsoftmax(pred), 1)) + + +class MultiHeadCrossEntropyLoss(nn.Module): + def forward(self, preds, targets): + assert preds.dim() == 3, preds + assert targets.dim() == 2, targets + + assert preds.size(1) == targets.size(1), (preds, targets) + num_heads = targets.size(1) + + loss = 0 + for k in range(num_heads): + loss += F.cross_entropy(preds[:, k, :], targets[:, k]) / num_heads + return loss + + +# ---------- + +SUPPORTED_CRITERIONS = { + "cross_entropy": torch.nn.CrossEntropyLoss(), + "cross_entropy_smooth": smoothed_cross_entropy_loss, + "cross_entropy_multihead": MultiHeadCrossEntropyLoss() +} + + +def criterion_builder(): + err_str = "Loss function type '{}' not supported" + assert cfg.SEARCH.LOSS_FUN in SUPPORTED_CRITERIONS.keys(), err_str.format(cfg.SEARCH.LOSS_FUN) + return SUPPORTED_CRITERIONS[cfg.SEARCH.LOSS_FUN] diff --git a/xnas/runner/optimizer.py b/xnas/runner/optimizer.py index 3bf8337..e523705 100644 --- a/xnas/runner/optimizer.py +++ b/xnas/runner/optimizer.py @@ -1,13 +1,13 @@ -"""Optimizers and loss functions.""" +"""Optimizers.""" import torch +import torch.nn as nn from xnas.core.config import cfg __all__ = [ 'optimizer_builder', - 'darts_alpha_optimizer', - 'criterion_builder' + 'darts_alpha_optimizer', ] @@ -16,10 +16,6 @@ SUPPORTED_OPTIMIZERS = { "Adam", } -SUPPORTED_CRITERIONS = { - "cross_entropy": torch.nn.CrossEntropyLoss(), -} - def optimizer_builder(name, param): """optimizer builder @@ -69,9 +65,3 @@ def darts_alpha_optimizer(name, param): betas=(0.5, 0.999), weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY, ) - - -def criterion_builder(): - err_str = "Loss function type '{}' not supported" - assert cfg.SEARCH.LOSS_FUN in SUPPORTED_CRITERIONS.keys(), err_str.format(cfg.SEARCH.LOSS_FUN) - return SUPPORTED_CRITERIONS[cfg.SEARCH.LOSS_FUN] diff --git a/xnas/runner/trainer.py b/xnas/runner/trainer.py index 2e77c67..a0786ca 100644 --- a/xnas/runner/trainer.py +++ b/xnas/runner/trainer.py @@ -47,6 +47,7 @@ class Recorder(): self.full_timer.toc() logger.info("Overall time cost: {}".format(str(self.full_timer.total_time))) gc.collect() + self.full_timer = None class Trainer(Recorder): @@ -261,7 +262,7 @@ class DartsTrainer(Trainer): class OneShotTrainer(Trainer): - def __init__(self, supernet, criterion, optimizer, lr_scheduler, train_loader, test_loader, train_sampler, evaluate_sampler): + def __init__(self, supernet, criterion, optimizer, lr_scheduler, train_loader, test_loader, sample_type='epoch'): super().__init__( model=supernet, criterion=criterion, @@ -269,11 +270,15 @@ class OneShotTrainer(Trainer): lr_scheduler=lr_scheduler, train_loader=train_loader, test_loader=test_loader) - self.train_sampler = train_sampler - self.evaluate_sampler = evaluate_sampler + self.iter_sampler = None + self.sample_type = sample_type + assert self.sample_type in ['epoch', 'iter'] self.evaluate_meter = meter.TestMeter(len(self.test_loader)) + + def register_iter_sample(self, sampler): + self.iter_sampler = sampler - def train_epoch(self, cur_epoch): + def train_epoch(self, cur_epoch, sample=None): """Sample path from supernet and train it.""" self.model.train() lr = self.lr_scheduler.get_last_lr()[0] @@ -283,8 +288,9 @@ class OneShotTrainer(Trainer): for cur_iter, (inputs, labels) in enumerate(self.train_loader): inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) # sample subnet - choice = self.train_sampler.suggest() - preds = self.model(inputs, choice) + if self.sample_type == 'iter': + sample = self.iter_sampler.suggest() + preds = self.model(inputs, sample) loss = self.criterion(preds, labels) self.optimizer.zero_grad() loss.backward() @@ -294,7 +300,8 @@ class OneShotTrainer(Trainer): # Compute the errors top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item() - self.train_sampler.record(choice, top1_err) # use top1_err as evaluation + if self.sample_type == 'iter': + self.iter_sampler.record(sample, top1_err) # use top1_err as evaluation self.train_meter.iter_toc() # Update and log stats self.train_meter.update_stats(top1_err, top5_err, loss, lr, inputs.size(0)) @@ -311,19 +318,22 @@ class OneShotTrainer(Trainer): # Saving checkpoint if (cur_epoch + 1) % cfg.SAVE_PERIOD == 0: self.saving(cur_epoch) + return top1_err @torch.no_grad() - def test_epoch(self, cur_epoch): + def test_epoch(self, cur_epoch, sample=None): self.model.eval() self.test_meter.iter_tic() for cur_iter, (inputs, labels) in enumerate(self.test_loader): inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) # sample subnet - choice = self.train_sampler.suggest() - preds = self.model(inputs, choice) + if self.sample_type == 'iter': + sample = self.iter_sampler.suggest() + preds = self.model(inputs, sample) top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) top1_err, top5_err = top1_err.item(), top5_err.item() - self.train_sampler.record(choice, top1_err) + if self.sample_type == 'iter': + self.iter_sampler.record(sample, top1_err) # use top1_err as evaluation self.test_meter.iter_toc() self.test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) self.test_meter.log_iter_stats(cur_epoch, cur_iter) @@ -338,20 +348,20 @@ class OneShotTrainer(Trainer): if self.best_err > top1_err: self.best_err = top1_err self.saving(cur_epoch, best=True) + return top1_err @torch.no_grad() - def best_arch(self, cycles): - """Return final best subnet architecture and its performance""" + def evaluate_epoch(self, sample): + """Return performance of the given sample (subnet)""" self.model.eval() - for c in range(cycles): - choice = self.evaluate_sampler.suggest() - for cur_iter, (inputs, labels) in enumerate(self.test_loader): - inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) - preds = self.model(inputs, choice) - top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) - top1_err, top5_err = top1_err.item(), top5_err.item() - self.evaluate_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) - top1_err = self.evaluate_meter.mb_top1_err.get_win_median() - self.evaluate_sampler.record(choice, top1_err) - self.evaluate_meter.reset() - return self.evaluate_sampler.final_best() + # choice = self.evaluate_sampler.suggest() + for cur_iter, (inputs, labels) in enumerate(self.test_loader): + inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) + preds = self.model(inputs, sample) + top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) + top1_err, top5_err = top1_err.item(), top5_err.item() + self.evaluate_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) + top1_err = self.evaluate_meter.mb_top1_err.get_win_median() + # self.evaluate_sampler.record(choice, top1_err) + self.evaluate_meter.reset() + return top1_err diff --git a/xnas/spaces/DARTS/cnn.py b/xnas/spaces/DARTS/cnn.py index a5bc3cf..dd613b4 100644 --- a/xnas/spaces/DARTS/cnn.py +++ b/xnas/spaces/DARTS/cnn.py @@ -100,6 +100,9 @@ class DartsCNN(nn.Module): self.norm_node_index = self._node_index(n_nodes, input_nodes=2, start_index=0) self.reduce_node_index = self._node_index(n_nodes, input_nodes=2, start_index=self.num_edges) + def weights(self): + return self.parameters() + def forward(self, x, sample): s0 = s1 = self.stem(x) diff --git a/xnas/spaces/SPOS/cnn.py b/xnas/spaces/SPOS/cnn.py index 64a7ddf..9d0e11e 100644 --- a/xnas/spaces/SPOS/cnn.py +++ b/xnas/spaces/SPOS/cnn.py @@ -176,6 +176,9 @@ class SPOS_supernet(nn.Module): # self.global_pooling = nn.AvgPool2d(7) self.classifier = nn.Linear(LAST_CHANNEL, self.classes, bias=False) self._initialize_weights() + + def weights(self): + return self.parameters() def forward(self, x, choice=np.random.randint(4, size=20)): x = self.stem(x) -- 2.34.1 From d23c8a5870f7748fbd9374abd12ae5be9ff9038e Mon Sep 17 00:00:00 2001 From: xfey Date: Tue, 24 May 2022 15:58:00 +0800 Subject: [PATCH 4/5] OFA added and some bugs fixed. --- README.md | 4 +- scripts/search/DrNAS.py | 3 +- scripts/search/OFA/train_supernet.py | 101 ++ xnas/algorithms/OFA/progressive_shrinking.py | 166 +++ xnas/core/builder.py | 6 + xnas/datasets/imagenet.py | 10 +- xnas/datasets/loader.py | 4 + xnas/runner/trainer.py | 8 +- xnas/spaces/OFA/MobileNetV3/cnn.py | 395 ++++++ xnas/spaces/OFA/MobileNetV3/ofa_cnn.py | 415 ++++++ xnas/spaces/OFA/ProxylessNet/cnn.py | 242 ++++ xnas/spaces/OFA/ProxylessNet/ofa_cnn.py | 390 ++++++ xnas/spaces/OFA/ResNets/cnn.py | 248 ++++ xnas/spaces/OFA/ResNets/ofa_cnn.py | 353 ++++++ xnas/spaces/OFA/dynamic_ops.py | 1182 ++++++++++++++++++ xnas/spaces/OFA/ops.py | 957 ++++++++++++++ xnas/spaces/OFA/utils.py | 111 ++ 17 files changed, 4584 insertions(+), 11 deletions(-) create mode 100644 scripts/search/OFA/train_supernet.py create mode 100644 xnas/algorithms/OFA/progressive_shrinking.py create mode 100644 xnas/spaces/OFA/MobileNetV3/cnn.py create mode 100644 xnas/spaces/OFA/MobileNetV3/ofa_cnn.py create mode 100644 xnas/spaces/OFA/ProxylessNet/cnn.py create mode 100644 xnas/spaces/OFA/ProxylessNet/ofa_cnn.py create mode 100644 xnas/spaces/OFA/ResNets/cnn.py create mode 100644 xnas/spaces/OFA/ResNets/ofa_cnn.py create mode 100644 xnas/spaces/OFA/dynamic_ops.py create mode 100644 xnas/spaces/OFA/ops.py create mode 100644 xnas/spaces/OFA/utils.py diff --git a/README.md b/README.md index 7463c5c..5daa51d 100644 --- a/README.md +++ b/README.md @@ -67,13 +67,13 @@ We are gradually providing support for more settings. To run XNAS, `python>=3.7` and `pytorch=1.9` are required. Other versions of `PyTorch` may also work well, but there are potential API differences that can cause warnings to be generated. -For detailed instructions, please refer to [**Getting_started.md**](./docs/Getting_started.md) and [**Data_preparation.md**](./docs/Data_preparation.md) in our docs. +For detailed instructions, please refer to [**Getting_started.md**](./docs/get_started.md) and [**Data_preparation.md**](./docs/data_preparation.md) in our docs. ## Contributing We welcome contributions to the library along with any potential issues or suggestions. -Please refer to [**Contributing.md**](./docs/Contributing.md) in our docs for more information. +Please refer to [**Contributing.md**](./docs/notes.md) in our docs for more information. ## Citation diff --git a/scripts/search/DrNAS.py b/scripts/search/DrNAS.py index d2859f6..ee7fd07 100755 --- a/scripts/search/DrNAS.py +++ b/scripts/search/DrNAS.py @@ -81,7 +81,7 @@ def main(): # check whether warm-up training is used if cfg.LOADER.BATCH_SIZE <= 256 or cfg.LOADER.DATASET != 'imagenet': cfg.OPTIM.WARMUP_EPOCH = 0 # DrNAS does not warm-up if batch_size is small - train_epochs[0] -= cfg.OPTIM.WARMUP_EPOCH + train_epochs[0] += cfg.OPTIM.WARMUP_EPOCH # init recorders drnas_trainer = DartsTrainer( @@ -120,6 +120,7 @@ def main(): writer=drnas_trainer.writer, cur_epoch=cur_epoch ) + start_epoch += 1 # set tau for snas & gdas if TAU_FLAG: tau_epoch += tau_step diff --git a/scripts/search/OFA/train_supernet.py b/scripts/search/OFA/train_supernet.py new file mode 100644 index 0000000..896e3d3 --- /dev/null +++ b/scripts/search/OFA/train_supernet.py @@ -0,0 +1,101 @@ +"""OFA supernet training.""" + +import xnas.core.config as config +import xnas.logger.logging as logging +from xnas.core.config import cfg +from xnas.core.builder import * + +# OFA +from xnas.runner.trainer import Trainer +from xnas.spaces.OFA.utils import init_model +from xnas.algorithms.OFA.progressive_shrinking import train_epoch, validate + + +# Load config and check +config.load_configs() +logger = logging.get_logger(__name__) + +def main(): + setup_env() + # Network + net = space_builder().cuda() + init_model(net) + # Loss function + criterion = criterion_builder() + # Data loaders + [train_loader, valid_loader] = construct_loader() + # Optimizers + net_params = [ + # parameters with weight decay + {"params": net.get_parameters(['bn', 'bias'], mode="exclude"), "weight_decay": cfg.OPTIM.WEIGHT_DECAY}, + # parameters without weight decay + {"params": net.get_parameters(['bn', 'bias'], mode="include") , "weight_decay": 0}, + ] + # init_lr = cfg.OPTIM.BASE_LR * cfg.NUM_GPUS # TODO: multi-GPU support + optimizer = optimizer_builder("SGD", net_params) + lr_scheduler = lr_scheduler_builder(optimizer) + + # Initialize Recorder + ofa_trainer = Trainer( + model=net, + criterion=criterion, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + train_loader=train_loader, + test_loader=valid_loader, + ) + # Resume + start_epoch = ofa_trainer.loading() if cfg.SEARCH.AUTO_RESUME else 0 + + # build validation config + validate_func_dict = { + "image_size_list": {cfg.TEST.IM_SIZE}, + "ks_list": sorted({min(net.ks_list), max(net.ks_list)}), + "expand_ratio_list": sorted({min(net.expand_ratio_list), max(net.expand_ratio_list)}), + "depth_list": sorted({min(net.depth_list), max(net.depth_list)}), + } + if cfg.OFA.TASK == 'normal': + pass + elif cfg.OFA.TASK == 'kernel': + validate_func_dict["ks_list"] = sorted(net.ks_list) + elif cfg.OFA.TASK == 'depth': + # add depth list constraints + if (len(set(net.ks_list)) == 1) and (len(set(net.expand_ratio_list)) == 1): + validate_func_dict["depth_list"] = net.depth_list + elif cfg.OFA.TASK == 'expand': + if len(set(net.ks_list)) == 1 and len(set(net.depth_list)) == 1: + validate_func_dict["expand_ratio_list"] = net.expand_ratio_list + else: + raise NotImplementedError + + # Training + logger.info("=== OFA | Task: {} | Phase: {} ===".format(cfg.OFA.TASK, cfg.OFA.PHASE)) + ofa_trainer.start() + for cur_epoch in range(start_epoch, cfg.OPTIM.WARMUP_EPOCH+cfg.OPTIM.MAX_EPOCH): + train_epoch( + train_=train_loader, + net=net, + train_criterion=criterion, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + writer=ofa_trainer.writer, + train_meter=ofa_trainer.train_meter, + cur_epoch=cur_epoch, + ) + if (cur_epoch + 1) % cfg.SAVE_PERIOD == 0: + ofa_trainer.saving(cur_epoch) + if (cur_epoch+1) % cfg.EVAL_PERIOD == 0 or (cur_epoch+1) == cfg.OPTIM.MAX_EPOCH: + validate( + val_=valid_loader, + net=net, + val_meter=ofa_trainer.test_meter, + cur_epoch=cur_epoch, + logger=logger, + **validate_func_dict, + ) + # TODO:OFA支持保存validation accuracy最高的checkpoint + ofa_trainer.finish() + + +if __name__ == '__main__': + main() diff --git a/xnas/algorithms/OFA/progressive_shrinking.py b/xnas/algorithms/OFA/progressive_shrinking.py new file mode 100644 index 0000000..775e252 --- /dev/null +++ b/xnas/algorithms/OFA/progressive_shrinking.py @@ -0,0 +1,166 @@ +import random +import torch +import torch.nn as nn + +import xnas.logger.meter as meter +from xnas.core.config import cfg +from xnas.spaces.OFA.utils import list_mean + +__all__ = [ + "validate", + "train_one_epoch", +] + + +def validate( + val_, net, val_meter, + cur_epoch, logger, + + image_size_list=None, + ks_list=None, + expand_ratio_list=None, + depth_list=None, + width_mult_list=None, + additional_setting=None, +): + # net + dynamic_net = net + if isinstance(dynamic_net, nn.DataParallel): + dynamic_net = dynamic_net.module + # eval mode + dynamic_net.eval() + + # net config + assert image_size_list is not None, 'validate: image_size should not be None' + + if ks_list is None: + ks_list = dynamic_net.ks_list + if expand_ratio_list is None: + expand_ratio_list = dynamic_net.expand_ratio_list + if depth_list is None: + depth_list = dynamic_net.depth_list + if width_mult_list is None: + if "width_mult_list" in dynamic_net.__dict__: + width_mult_list = list(range(len(dynamic_net.width_mult_list))) + else: + width_mult_list = [0] + + + # 获取所有subnet的setting + subnet_settings = [] + # img_size = cfg.TEST.IM_SIZE + for d in depth_list: + for e in expand_ratio_list: + for k in ks_list: + for w in width_mult_list: + for img_size in image_size_list: + subnet_settings.append( + [ + { + "img_size": img_size, + "d": d, + "e": e, + "ks": k, + "w": w, + }, + "R%s-D%s-E%s-K%s-W%s" % (img_size, d, e, k, w), + ] + ) + if additional_setting is not None: + subnet_settings += additional_setting + + # 遍历评估所有subnet + for setting, name in subnet_settings: + dynamic_net.set_active_subnet(**setting) + logger.info('epoch: '+str(cur_epoch+1) + ' || validate subnet: '+ name) + test_epoch(val_, dynamic_net, val_meter, cur_epoch) + + +def train_epoch( + train_, net, train_criterion, + optimizer, lr_scheduler, writer, train_meter, + cur_epoch, + ): + + nBatch = len(train_) + cur_step = cur_epoch*nBatch + + for cur_iter, (images, labels) in enumerate(train_): + cur_step += 1 + net.train() + train_meter.iter_tic() # 初始化时间 + + images, labels = images.cuda(), labels.cuda() + + # clean gradients + net.zero_grad() + + # set random seed before sampling + subnet_seed = int("%d%.3d" % (cur_step, 0)) + random.seed(subnet_seed) + # subset setting + subnet_settings = net.sample_active_subnet() + subnet_str = ",".join( + [ + "%s_%s" + % ( + key, + "%.1f" % list_mean([val[0]]) + if isinstance(val, list) + else val, + ) + for key, val in subnet_settings.items() + ] + ) + # compute output + output = net(images) + loss = train_criterion(output, labels) + loss.backward() + optimizer.step() + + # measure top1&top5 error + top1_err, top5_err = meter.topk_errors(output, labels, [1, 5]) + # Copy the stats from GPU to CPU (sync point) + loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item() + train_meter.iter_toc() + # Update and log stats + mb_size = images.size(0) + cur_lr = lr_scheduler.get_last_lr()[0] + train_meter.update_stats(top1_err, top5_err, loss, cur_lr, mb_size) + train_meter.log_iter_stats(cur_epoch, cur_iter) + train_meter.iter_tic() + # write to tensorboard + writer.add_scalar('train/lr', cur_lr, cur_step) + writer.add_scalar('train/loss', loss, cur_step) + writer.add_scalar('train/top1_error', top1_err, cur_step) + writer.add_scalar('train/top5_error', top5_err, cur_step) + # update lr + lr_scheduler.step(cur_epoch + cur_iter / nBatch) + # Log epoch stats + train_meter.log_epoch_stats(cur_epoch) + train_meter.reset() + return + + +@torch.no_grad() +def test_epoch(test_loader, model, test_meter, cur_epoch, writer=None): + model.eval() + test_meter.iter_tic() + for cur_iter, (inputs, labels) in enumerate(test_loader): + inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) + preds = model(inputs) + top1_err, top5_err = meter.topk_errors(preds, labels, [1, 5]) + top1_err, top5_err = top1_err.item(), top5_err.item() + + test_meter.iter_toc() + test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) + test_meter.log_iter_stats(cur_epoch, cur_iter) + test_meter.iter_tic() + top1_err = test_meter.mb_top1_err.get_win_median() + if writer is not None: + writer.add_scalar('val/top1_error', test_meter.mb_top1_err.get_win_median(), cur_epoch) + writer.add_scalar('val/top5_error', test_meter.mb_top5_err.get_win_median(), cur_epoch) + # Log epoch stats + test_meter.log_epoch_stats(cur_epoch) + test_meter.reset() + return top1_err diff --git a/xnas/core/builder.py b/xnas/core/builder.py index d7491bf..830eb05 100644 --- a/xnas/core/builder.py +++ b/xnas/core/builder.py @@ -50,6 +50,9 @@ from xnas.spaces.DrNAS.darts_cnn import _DrNAS_DARTS_CNN from xnas.spaces.DrNAS.nb201_cnn import _DrNAS_nb201_CNN, _GDAS_nb201_CNN from xnas.spaces.SPOS.cnn import _SPOS_CNN, _infer_SPOS_CNN from xnas.spaces.DropNAS.cnn import _DropNASCNN +from xnas.spaces.OFA.MobileNetV3.ofa_cnn import _OFAMobileNetV3 +from xnas.spaces.OFA.ProxylessNet.ofa_cnn import _OFAProxylessNASNet +from xnas.spaces.OFA.ResNets.ofa_cnn import _OFAResNet SUPPORTED_SPACES = { @@ -62,6 +65,9 @@ SUPPORTED_SPACES = { "gdas_nb201": _GDAS_nb201_CNN, "dropnas": _DropNASCNN, "spos": _SPOS_CNN, + "ofa_mbv3": _OFAMobileNetV3, + "ofa_proxyless": _OFAProxylessNASNet, + "ofa_resnet": _OFAResNet, # models for inference "infer_darts": _infer_DartsCNN, "infer_nb201": _infer_NASBench201, diff --git a/xnas/datasets/imagenet.py b/xnas/datasets/imagenet.py index ffff037..6ebcc92 100644 --- a/xnas/datasets/imagenet.py +++ b/xnas/datasets/imagenet.py @@ -33,8 +33,8 @@ class ImageFolder(): _rgb_normalized_mean=None, _rgb_normalized_std=None, transforms=None, - num_workers=cfg.LOADER.NUM_WORKERS, - pin_memory=cfg.LOADER.PIN_MEMORY, + num_workers=None, + pin_memory=None, shuffle=True ): assert os.path.exists(datapath), "Data path '{}' not found".format(datapath) @@ -42,10 +42,10 @@ class ImageFolder(): self._data_path, self._split, self.dataset_name = datapath, split, dataset_name self._rgb_normalized_mean, self._rgb_normalized_std = _rgb_normalized_mean, _rgb_normalized_std - self.num_workers = num_workers - self.batch_size = batch_size - self.pin_memory = pin_memory + self.num_workers = cfg.LOADER.NUM_WORKERS if num_workers is None else num_workers + self.pin_memory = cfg.LOADER.PIN_MEMORY if pin_memory is None else pin_memory self.shuffle = shuffle + self.batch_size = batch_size if transforms is None: self.transforms = [{'crop': 'random', 'crop_size': cfg.SEARCH.IM_SIZE, 'min_crop': 0.08, 'random_flip': True}, {'crop': 'center', 'crop_size': cfg.TEST.IM_SIZE, 'min_crop': -1, 'random_flip': False}] # NOTE: min_crop is not used here. diff --git a/xnas/datasets/loader.py b/xnas/datasets/loader.py index 92a1867..a571dae 100644 --- a/xnas/datasets/loader.py +++ b/xnas/datasets/loader.py @@ -43,6 +43,10 @@ def construct_loader( batch_size = [256, 256] assert len(batch_size) == len(split), "lengths of batch_size and split should be same." + # check if randomresized crop is used only in ImageFolder type datasets + if isinstance(cfg.SEARCH.IMG_SIZE, list): + assert name in IMAGEFOLDER_FORMAT, "RandomResizedCrop can only be used in ImageFolder currently." + if name in SUPPORTED_DATASETS: # using training data only. train_data, _ = get_data(name, datapath, cutout_length, use_classes=use_classes, transforms=transforms) diff --git a/xnas/runner/trainer.py b/xnas/runner/trainer.py index a0786ca..2882e00 100644 --- a/xnas/runner/trainer.py +++ b/xnas/runner/trainer.py @@ -22,7 +22,7 @@ from xnas.core.config import cfg from torch.utils.tensorboard import SummaryWriter -__all__ = ["Trainer", "DartsTrainer"] +__all__ = ["Trainer", "DartsTrainer", "OneShotTrainer"] logger = logging.get_logger(__name__) @@ -155,8 +155,10 @@ class Trainer(Recorder): """Load from checkpoint.""" ckpt_epoch, ckpt_dict = self.resume() if ckpt_epoch != -1: - self.optimizer.load_state_dict(ckpt_dict['optimizer']) - self.lr_scheduler.load_state_dict(ckpt_dict['lr_scheduler']) + if self.optimizer is not None: + self.optimizer.load_state_dict(ckpt_dict['optimizer']) + if self.lr_scheduler is not None: + self.lr_scheduler.load_state_dict(ckpt_dict['lr_scheduler']) return ckpt_epoch + 1 else: return 0 diff --git a/xnas/spaces/OFA/MobileNetV3/cnn.py b/xnas/spaces/OFA/MobileNetV3/cnn.py new file mode 100644 index 0000000..a497cb6 --- /dev/null +++ b/xnas/spaces/OFA/MobileNetV3/cnn.py @@ -0,0 +1,395 @@ +import torch +import torch.nn as nn +from copy import deepcopy + +from xnas.spaces.OFA.ops import * +from xnas.spaces.OFA.utils import min_divisible_value + + +__all__ = ["WSConv_Network", "MobileNetV3", "MobileNetV3Large"] + + +class WSConv_Network(nn.Module): + """Network with all Conv2d replaced by Weight Standard Conv2d.""" + + def set_bn_param(self, momentum, eps, gn_channel_per_group=None, ws_eps=None, **kwargs): + + """Replace BN with GN""" + if gn_channel_per_group is None: + return + + for m in self.modules(): + to_replace_dict = {} + for name, sub_m in m.named_children(): + if isinstance(sub_m, nn.BatchNorm2d): + num_groups = sub_m.num_features // min_divisible_value( + sub_m.num_features, gn_channel_per_group + ) + gn_m = nn.GroupNorm( + num_groups=num_groups, + num_channels=sub_m.num_features, + eps=sub_m.eps, + affine=True, + ) + + # load weight + gn_m.weight.data.copy_(sub_m.weight.data) + gn_m.bias.data.copy_(sub_m.bias.data) + # load requires_grad + gn_m.weight.requires_grad = sub_m.weight.requires_grad + gn_m.bias.requires_grad = sub_m.bias.requires_grad + + to_replace_dict[name] = gn_m + m._modules.update(to_replace_dict) + + """Init Norm params""" + for m in self.modules(): + if type(m) in [nn.BatchNorm1d, nn.BatchNorm2d]: + m.momentum = momentum + m.eps = eps + elif isinstance(m, nn.GroupNorm): + m.eps = eps + + """Replace Conv2d with WeightStandardConv2d""" + if ws_eps is None: + return + + for m in self.modules(): + to_update_dict = {} + for name, sub_module in m.named_children(): + if isinstance(sub_module, nn.Conv2d) and not sub_module.bias: + # only replace conv2d layers that are followed by normalization layers (i.e., no bias) + to_update_dict[name] = sub_module + for name, sub_module in to_update_dict.items(): + m._modules[name] = WeightStandardConv2d( + sub_module.in_channels, + sub_module.out_channels, + sub_module.kernel_size, + sub_module.stride, + sub_module.padding, + sub_module.dilation, + sub_module.groups, + sub_module.bias, + ) + # load weight + m._modules[name].load_state_dict(sub_module.state_dict()) + # load requires_grad + m._modules[name].weight.requires_grad = sub_module.weight.requires_grad + if sub_module.bias is not None: + m._modules[name].bias.requires_grad = sub_module.bias.requires_grad + # set ws_eps + for m in self.modules(): + if isinstance(m, WeightStandardConv2d): + m.WS_EPS = ws_eps + + def get_bn_param(self): + ws_eps = None + for m in self.modules(): + if isinstance(m, WeightStandardConv2d): + ws_eps = m.WS_EPS + break + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + return { + "momentum": m.momentum, + "eps": m.eps, + "ws_eps": ws_eps, + } + elif isinstance(m, nn.GroupNorm): + return { + "momentum": None, + "eps": m.eps, + "gn_channel_per_group": m.num_channels // m.num_groups, + "ws_eps": ws_eps, + } + return None + + def get_parameters(self, keys=None, mode="include"): + if keys is None: + for name, param in self.named_parameters(): + if param.requires_grad: + yield param + elif mode == "include": + for name, param in self.named_parameters(): + flag = False + for key in keys: + if key in name: + flag = True + break + if flag and param.requires_grad: + yield param + elif mode == "exclude": + for name, param in self.named_parameters(): + flag = True + for key in keys: + if key in name: + flag = False + break + if flag and param.requires_grad: + yield param + else: + raise ValueError("do not support: %s" % mode) + + +class MobileNetV3(WSConv_Network): + def __init__( + self, first_conv, blocks, final_expand_layer, feature_mix_layer, classifier + ): + super(MobileNetV3, self).__init__() + + self.first_conv = first_conv + self.blocks = nn.ModuleList(blocks) + self.final_expand_layer = final_expand_layer + self.global_avg_pool = GlobalAvgPool2d(keep_dim=True) + self.feature_mix_layer = feature_mix_layer + self.classifier = classifier + + def forward(self, x): + x = self.first_conv(x) + for block in self.blocks: + x = block(x) + x = self.final_expand_layer(x) + x = self.global_avg_pool(x) # global average pooling + x = self.feature_mix_layer(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + @property + def module_str(self): + _str = self.first_conv.module_str + "\n" + for block in self.blocks: + _str += block.module_str + "\n" + _str += self.final_expand_layer.module_str + "\n" + _str += self.global_avg_pool.__repr__() + "\n" + _str += self.feature_mix_layer.module_str + "\n" + _str += self.classifier.module_str + return _str + + @property + def config(self): + return { + "name": MobileNetV3.__name__, + "bn": self.get_bn_param(), + "first_conv": self.first_conv.config, + "blocks": [block.config for block in self.blocks], + "final_expand_layer": self.final_expand_layer.config, + "feature_mix_layer": self.feature_mix_layer.config, + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + first_conv = set_layer_from_config(config["first_conv"]) + final_expand_layer = set_layer_from_config(config["final_expand_layer"]) + feature_mix_layer = set_layer_from_config(config["feature_mix_layer"]) + classifier = set_layer_from_config(config["classifier"]) + + blocks = [] + for block_config in config["blocks"]: + blocks.append(ResidualBlock.build_from_config(block_config)) + + net = MobileNetV3( + first_conv, blocks, final_expand_layer, feature_mix_layer, classifier + ) + if "bn" in config: + net.set_bn_param(**config["bn"]) + else: + net.set_bn_param(momentum=0.1, eps=1e-5) + + return net + + def zero_last_gamma(self): + for m in self.modules(): + if isinstance(m, ResidualBlock): + if isinstance(m.conv, MBConvLayer) and isinstance( + m.shortcut, IdentityLayer + ): + m.conv.point_linear.bn.weight.data.zero_() + + @property + def grouped_block_index(self): + info_list = [] + block_index_list = [] + for i, block in enumerate(self.blocks[1:], 1): + if block.shortcut is None and len(block_index_list) > 0: + info_list.append(block_index_list) + block_index_list = [] + block_index_list.append(i) + if len(block_index_list) > 0: + info_list.append(block_index_list) + return info_list + + @staticmethod + def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate): + # first conv layer + first_conv = ConvLayer( + 3, + input_channel, + kernel_size=3, + stride=2, + use_bn=True, + act_func="h_swish", + ops_order="weight_bn_act", + ) + # build mobile blocks + feature_dim = input_channel + blocks = [] + for stage_id, block_config_list in cfg.items(): + for ( + k, + mid_channel, + out_channel, + use_se, + act_func, + stride, + expand_ratio, + ) in block_config_list: + mb_conv = MBConvLayer( + feature_dim, + out_channel, + k, + stride, + expand_ratio, + mid_channel, + act_func, + use_se, + ) + if stride == 1 and out_channel == feature_dim: + shortcut = IdentityLayer(out_channel, out_channel) + else: + shortcut = None + blocks.append(ResidualBlock(mb_conv, shortcut)) + feature_dim = out_channel + # final expand layer + final_expand_layer = ConvLayer( + feature_dim, + feature_dim * 6, + kernel_size=1, + use_bn=True, + act_func="h_swish", + ops_order="weight_bn_act", + ) + # feature mix layer + feature_mix_layer = ConvLayer( + feature_dim * 6, + last_channel, + kernel_size=1, + bias=False, + use_bn=False, + act_func="h_swish", + ) + # classifier + classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate) + + return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier + + @staticmethod + def adjust_cfg( + cfg, ks=None, expand_ratio=None, depth_param=None, stage_width_list=None + ): + for i, (stage_id, block_config_list) in enumerate(cfg.items()): + for block_config in block_config_list: + if ks is not None and stage_id != "0": + block_config[0] = ks + if expand_ratio is not None and stage_id != "0": + block_config[-1] = expand_ratio + block_config[1] = None + if stage_width_list is not None: + block_config[2] = stage_width_list[i] + if depth_param is not None and stage_id != "0": + new_block_config_list = [block_config_list[0]] + new_block_config_list += [ + deepcopy(block_config_list[-1]) for _ in range(depth_param - 1) + ] + cfg[stage_id] = new_block_config_list + return cfg + + def load_state_dict(self, state_dict, **kwargs): + current_state_dict = self.state_dict() + + for key in state_dict: + if key not in current_state_dict: + assert ".mobile_inverted_conv." in key + new_key = key.replace(".mobile_inverted_conv.", ".conv.") + else: + new_key = key + current_state_dict[new_key] = state_dict[key] + super(MobileNetV3, self).load_state_dict(current_state_dict) + + +class MobileNetV3Large(MobileNetV3): + def __init__( + self, + n_classes=1000, + width_mult=1.0, + bn_param=(0.1, 1e-5), + dropout_rate=0.2, + ks=None, + expand_ratio=None, + depth_param=None, + stage_width_list=None, + ): + input_channel = 16 + last_channel = 1280 + + input_channel = make_divisible(input_channel * width_mult) + last_channel = ( + make_divisible(last_channel * width_mult) + if width_mult > 1.0 + else last_channel + ) + + cfg = { + # k, exp, c, se, nl, s, e, + "0": [ + [3, 16, 16, False, "relu", 1, 1], + ], + "1": [ + [3, 64, 24, False, "relu", 2, None], # 4 + [3, 72, 24, False, "relu", 1, None], # 3 + ], + "2": [ + [5, 72, 40, True, "relu", 2, None], # 3 + [5, 120, 40, True, "relu", 1, None], # 3 + [5, 120, 40, True, "relu", 1, None], # 3 + ], + "3": [ + [3, 240, 80, False, "h_swish", 2, None], # 6 + [3, 200, 80, False, "h_swish", 1, None], # 2.5 + [3, 184, 80, False, "h_swish", 1, None], # 2.3 + [3, 184, 80, False, "h_swish", 1, None], # 2.3 + ], + "4": [ + [3, 480, 112, True, "h_swish", 1, None], # 6 + [3, 672, 112, True, "h_swish", 1, None], # 6 + ], + "5": [ + [5, 672, 160, True, "h_swish", 2, None], # 6 + [5, 960, 160, True, "h_swish", 1, None], # 6 + [5, 960, 160, True, "h_swish", 1, None], # 6 + ], + } + + cfg = self.adjust_cfg(cfg, ks, expand_ratio, depth_param, stage_width_list) + # width multiplier on mobile setting, change `exp: 1` and `c: 2` + for stage_id, block_config_list in cfg.items(): + for block_config in block_config_list: + if block_config[1] is not None: + block_config[1] = make_divisible(block_config[1] * width_mult) + block_config[2] = make_divisible(block_config[2] * width_mult) + + ( + first_conv, + blocks, + final_expand_layer, + feature_mix_layer, + classifier, + ) = self.build_net_via_cfg( + cfg, input_channel, last_channel, n_classes, dropout_rate + ) + super(MobileNetV3Large, self).__init__( + first_conv, blocks, final_expand_layer, feature_mix_layer, classifier + ) + # set bn param + self.set_bn_param(*bn_param) diff --git a/xnas/spaces/OFA/MobileNetV3/ofa_cnn.py b/xnas/spaces/OFA/MobileNetV3/ofa_cnn.py new file mode 100644 index 0000000..da36382 --- /dev/null +++ b/xnas/spaces/OFA/MobileNetV3/ofa_cnn.py @@ -0,0 +1,415 @@ +import random +from copy import deepcopy + +from xnas.spaces.OFA.dynamic_ops import DynamicMBConvLayer +from xnas.spaces.OFA.utils import val2list, make_divisible +from xnas.spaces.OFA.ops import ( + ConvLayer, + IdentityLayer, + LinearLayer, + MBConvLayer, + ResidualBlock, +) +from xnas.spaces.OFA.MobileNetV3.cnn import MobileNetV3 + + +__all__ = ["_OFAMobileNetV3", "OFAMobileNetV3"] + + +class OFAMobileNetV3(MobileNetV3): + def __init__( + self, + n_classes=1000, + bn_param=(0.1, 1e-5), + dropout_rate=0.1, + base_stage_width=None, + width_mult=1.0, + ks_list=3, + expand_ratio_list=6, + depth_list=4, + ): + + self.width_mult = width_mult + self.ks_list = val2list(ks_list, 1) + self.expand_ratio_list = val2list(expand_ratio_list, 1) + self.depth_list = val2list(depth_list, 1) + + self.ks_list.sort() + self.expand_ratio_list.sort() + self.depth_list.sort() + + base_stage_width = [16, 16, 24, 40, 80, 112, 160, 960, 1280] + + final_expand_width = make_divisible(base_stage_width[-2] * self.width_mult) + last_channel = make_divisible(base_stage_width[-1] * self.width_mult) + + stride_stages = [1, 2, 2, 2, 1, 2] + act_stages = ["relu", "relu", "relu", "h_swish", "h_swish", "h_swish"] + se_stages = [False, False, True, False, True, True] + n_block_list = [1] + [max(self.depth_list)] * 5 + width_list = [] + for base_width in base_stage_width[:-2]: + width = make_divisible(base_width * self.width_mult) + width_list.append(width) + + input_channel, first_block_dim = width_list[0], width_list[1] + # first conv layer + first_conv = ConvLayer( + 3, input_channel, kernel_size=3, stride=2, act_func="h_swish" + ) + first_block_conv = MBConvLayer( + in_channels=input_channel, + out_channels=first_block_dim, + kernel_size=3, + stride=stride_stages[0], + expand_ratio=1, + act_func=act_stages[0], + use_se=se_stages[0], + ) + first_block = ResidualBlock( + first_block_conv, + IdentityLayer(first_block_dim, first_block_dim) + if input_channel == first_block_dim + else None, + ) + + # inverted residual blocks + self.block_group_info = [] + blocks = [first_block] + _block_index = 1 + feature_dim = first_block_dim + + for width, n_block, s, act_func, use_se in zip( + width_list[2:], + n_block_list[1:], + stride_stages[1:], + act_stages[1:], + se_stages[1:], + ): + self.block_group_info.append([_block_index + i for i in range(n_block)]) + _block_index += n_block + + output_channel = width + for i in range(n_block): + if i == 0: + stride = s + else: + stride = 1 + mobile_inverted_conv = DynamicMBConvLayer( + in_channel_list=val2list(feature_dim), + out_channel_list=val2list(output_channel), + kernel_size_list=ks_list, + expand_ratio_list=expand_ratio_list, + stride=stride, + act_func=act_func, + use_se=use_se, + ) + if stride == 1 and feature_dim == output_channel: + shortcut = IdentityLayer(feature_dim, feature_dim) + else: + shortcut = None + blocks.append(ResidualBlock(mobile_inverted_conv, shortcut)) + feature_dim = output_channel + # final expand layer, feature mix layer & classifier + final_expand_layer = ConvLayer( + feature_dim, final_expand_width, kernel_size=1, act_func="h_swish" + ) + feature_mix_layer = ConvLayer( + final_expand_width, + last_channel, + kernel_size=1, + bias=False, + use_bn=False, + act_func="h_swish", + ) + + classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate) + + super(OFAMobileNetV3, self).__init__( + first_conv, blocks, final_expand_layer, feature_mix_layer, classifier + ) + + # set bn param + self.set_bn_param(momentum=bn_param[0], eps=bn_param[1]) + + # runtime_depth + self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info] + + """ WSConv_Network required methods """ + + @staticmethod + def name(): + return "OFAMobileNetV3" + + def forward(self, x): + # first conv + x = self.first_conv(x) + # first block + x = self.blocks[0](x) + # blocks + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + for idx in active_idx: + x = self.blocks[idx](x) + x = self.final_expand_layer(x) + x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling + x = self.feature_mix_layer(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x + + @property + def module_str(self): + _str = self.first_conv.module_str + "\n" + _str += self.blocks[0].module_str + "\n" + + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + for idx in active_idx: + _str += self.blocks[idx].module_str + "\n" + + _str += self.final_expand_layer.module_str + "\n" + _str += self.feature_mix_layer.module_str + "\n" + _str += self.classifier.module_str + "\n" + return _str + + @property + def config(self): + return { + "name": OFAMobileNetV3.__name__, + "bn": self.get_bn_param(), + "first_conv": self.first_conv.config, + "blocks": [block.config for block in self.blocks], + "final_expand_layer": self.final_expand_layer.config, + "feature_mix_layer": self.feature_mix_layer.config, + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + raise ValueError("do not support this function") + + @property + def grouped_block_index(self): + return self.block_group_info + + def load_state_dict(self, state_dict, **kwargs): + model_dict = self.state_dict() + for key in state_dict: + if ".mobile_inverted_conv." in key: + new_key = key.replace(".mobile_inverted_conv.", ".conv.") + else: + new_key = key + if new_key in model_dict: + pass + elif ".bn.bn." in new_key: + new_key = new_key.replace(".bn.bn.", ".bn.") + elif ".conv.conv.weight" in new_key: + new_key = new_key.replace(".conv.conv.weight", ".conv.weight") + elif ".linear.linear." in new_key: + new_key = new_key.replace(".linear.linear.", ".linear.") + ############################################################################## + elif ".linear." in new_key: + new_key = new_key.replace(".linear.", ".linear.linear.") + elif "bn." in new_key: + new_key = new_key.replace("bn.", "bn.bn.") + elif "conv.weight" in new_key: + new_key = new_key.replace("conv.weight", "conv.conv.weight") + else: + raise ValueError(new_key) + assert new_key in model_dict, "%s" % new_key + model_dict[new_key] = state_dict[key] + super(OFAMobileNetV3, self).load_state_dict(model_dict) + + """ set, sample and get active sub-networks """ + + def set_max_net(self): + self.set_active_subnet( + ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list) + ) + + def set_active_subnet(self, ks=None, e=None, d=None, **kwargs): + ks = val2list(ks, len(self.blocks) - 1) + expand_ratio = val2list(e, len(self.blocks) - 1) + depth = val2list(d, len(self.block_group_info)) + + for block, k, e in zip(self.blocks[1:], ks, expand_ratio): + if k is not None: + block.conv.active_kernel_size = k + if e is not None: + block.conv.active_expand_ratio = e + + for i, d in enumerate(depth): + if d is not None: + self.runtime_depth[i] = min(len(self.block_group_info[i]), d) + + def set_constraint(self, include_list, constraint_type="depth"): + if constraint_type == "depth": + self.__dict__["_depth_include_list"] = include_list.copy() + elif constraint_type == "expand_ratio": + self.__dict__["_expand_include_list"] = include_list.copy() + elif constraint_type == "kernel_size": + self.__dict__["_ks_include_list"] = include_list.copy() + else: + raise NotImplementedError + + def clear_constraint(self): + self.__dict__["_depth_include_list"] = None + self.__dict__["_expand_include_list"] = None + self.__dict__["_ks_include_list"] = None + + def sample_active_subnet(self): + ks_candidates = ( + self.ks_list + if self.__dict__.get("_ks_include_list", None) is None + else self.__dict__["_ks_include_list"] + ) + expand_candidates = ( + self.expand_ratio_list + if self.__dict__.get("_expand_include_list", None) is None + else self.__dict__["_expand_include_list"] + ) + depth_candidates = ( + self.depth_list + if self.__dict__.get("_depth_include_list", None) is None + else self.__dict__["_depth_include_list"] + ) + + # sample kernel size + ks_setting = [] + if not isinstance(ks_candidates[0], list): + ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)] + for k_set in ks_candidates: + k = random.choice(k_set) + ks_setting.append(k) + + # sample expand ratio + expand_setting = [] + if not isinstance(expand_candidates[0], list): + expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)] + for e_set in expand_candidates: + e = random.choice(e_set) + expand_setting.append(e) + + # sample depth + depth_setting = [] + if not isinstance(depth_candidates[0], list): + depth_candidates = [ + depth_candidates for _ in range(len(self.block_group_info)) + ] + for d_set in depth_candidates: + d = random.choice(d_set) + depth_setting.append(d) + + self.set_active_subnet(ks_setting, expand_setting, depth_setting) + + return { + "ks": ks_setting, + "e": expand_setting, + "d": depth_setting, + } + + def get_active_subnet(self, preserve_weight=True): + first_conv = deepcopy(self.first_conv) + blocks = [deepcopy(self.blocks[0])] + + final_expand_layer = deepcopy(self.final_expand_layer) + feature_mix_layer = deepcopy(self.feature_mix_layer) + classifier = deepcopy(self.classifier) + + input_channel = blocks[0].conv.out_channels + # blocks + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + stage_blocks = [] + for idx in active_idx: + stage_blocks.append( + ResidualBlock( + self.blocks[idx].conv.get_active_subnet( + input_channel, preserve_weight + ), + deepcopy(self.blocks[idx].shortcut), + ) + ) + input_channel = stage_blocks[-1].conv.out_channels + blocks += stage_blocks + + _subnet = MobileNetV3( + first_conv, blocks, final_expand_layer, feature_mix_layer, classifier + ) + _subnet.set_bn_param(**self.get_bn_param()) + return _subnet + + def get_active_net_config(self): + # first conv + first_conv_config = self.first_conv.config + first_block_config = self.blocks[0].config + final_expand_config = self.final_expand_layer.config + feature_mix_layer_config = self.feature_mix_layer.config + classifier_config = self.classifier.config + + block_config_list = [first_block_config] + input_channel = first_block_config["conv"]["out_channels"] + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + stage_blocks = [] + for idx in active_idx: + stage_blocks.append( + { + "name": ResidualBlock.__name__, + "conv": self.blocks[idx].conv.get_active_subnet_config( + input_channel + ), + "shortcut": self.blocks[idx].shortcut.config + if self.blocks[idx].shortcut is not None + else None, + } + ) + input_channel = self.blocks[idx].conv.active_out_channel + block_config_list += stage_blocks + + return { + "name": MobileNetV3.__name__, + "bn": self.get_bn_param(), + "first_conv": first_conv_config, + "blocks": block_config_list, + "final_expand_layer": final_expand_config, + "feature_mix_layer": feature_mix_layer_config, + "classifier": classifier_config, + } + + """ Width Related Methods """ + + def re_organize_middle_weights(self, expand_ratio_stage=0): + for block in self.blocks[1:]: + block.conv.re_organize_middle_weights(expand_ratio_stage) + + +def _OFAMobileNetV3(): + from xnas.core.config import cfg + + width_mult_list = cfg.OFA.WIDTH_MULTI_LIST + ks_list = cfg.OFA.KS_LIST + expand_list = cfg.OFA.EXPAND_LIST + depth_list = cfg.OFA.DEPTH_LIST + + width_mult_list = ( + cfg.OFA.WIDTH_MULTI_LIST[0] + if len(cfg.OFA.WIDTH_MULTI_LIST) == 1 + else cfg.OFA.WIDTH_MULTI_LIST + ) + + return OFAMobileNetV3( + n_classes=cfg.SEARCH.NUM_CLASSES, + bn_param=(0.1, 1e-5), + dropout_rate=0.1, + base_stage_width=None, + width_mult=width_mult_list, + ks_list=ks_list, + expand_ratio_list=expand_list, + depth_list=depth_list, + ) diff --git a/xnas/spaces/OFA/ProxylessNet/cnn.py b/xnas/spaces/OFA/ProxylessNet/cnn.py new file mode 100644 index 0000000..b8d54cb --- /dev/null +++ b/xnas/spaces/OFA/ProxylessNet/cnn.py @@ -0,0 +1,242 @@ +import json +import torch.nn as nn + +from xnas.spaces.OFA.ops import ( + set_layer_from_config, + MBConvLayer, + ConvLayer, + IdentityLayer, + LinearLayer, + ResidualBlock, + GlobalAvgPool2d, +) +from xnas.spaces.OFA.utils import val2list, make_divisible +from xnas.spaces.OFA.MobileNetV3.cnn import WSConv_Network + + +__all__ = ["proxyless_base", "ProxylessNASNet", "MobileNetV2"] + + +def proxyless_base( + net_config=None, + n_classes=None, + bn_param=None, + dropout_rate=None, +): + assert net_config is not None, "Please input a network config" + net_config_json = json.load(open(net_config, "r")) + + if n_classes is not None: + net_config_json["classifier"]["out_features"] = n_classes + if dropout_rate is not None: + net_config_json["classifier"]["dropout_rate"] = dropout_rate + + net = ProxylessNASNet.build_from_config(net_config_json) + if bn_param is not None: + net.set_bn_param(*bn_param) + + return net + + +class ProxylessNASNet(WSConv_Network): + def __init__(self, first_conv, blocks, feature_mix_layer, classifier): + super(ProxylessNASNet, self).__init__() + + self.first_conv = first_conv + self.blocks = nn.ModuleList(blocks) + self.feature_mix_layer = feature_mix_layer + self.global_avg_pool = GlobalAvgPool2d(keep_dim=False) + self.classifier = classifier + + def forward(self, x): + x = self.first_conv(x) + for block in self.blocks: + x = block(x) + if self.feature_mix_layer is not None: + x = self.feature_mix_layer(x) + x = self.global_avg_pool(x) + x = self.classifier(x) + return x + + @property + def module_str(self): + _str = self.first_conv.module_str + "\n" + for block in self.blocks: + _str += block.module_str + "\n" + _str += self.feature_mix_layer.module_str + "\n" + _str += self.global_avg_pool.__repr__() + "\n" + _str += self.classifier.module_str + return _str + + @property + def config(self): + return { + "name": ProxylessNASNet.__name__, + "bn": self.get_bn_param(), + "first_conv": self.first_conv.config, + "blocks": [block.config for block in self.blocks], + "feature_mix_layer": None + if self.feature_mix_layer is None + else self.feature_mix_layer.config, + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + first_conv = set_layer_from_config(config["first_conv"]) + feature_mix_layer = set_layer_from_config(config["feature_mix_layer"]) + classifier = set_layer_from_config(config["classifier"]) + + blocks = [] + for block_config in config["blocks"]: + blocks.append(ResidualBlock.build_from_config(block_config)) + + net = ProxylessNASNet(first_conv, blocks, feature_mix_layer, classifier) + if "bn" in config: + net.set_bn_param(**config["bn"]) + else: + net.set_bn_param(momentum=0.1, eps=1e-3) + + return net + + def zero_last_gamma(self): + for m in self.modules(): + if isinstance(m, ResidualBlock): + if isinstance(m.conv, MBConvLayer) and isinstance( + m.shortcut, IdentityLayer + ): + m.conv.point_linear.bn.weight.data.zero_() + + @property + def grouped_block_index(self): + info_list = [] + block_index_list = [] + for i, block in enumerate(self.blocks[1:], 1): + if block.shortcut is None and len(block_index_list) > 0: + info_list.append(block_index_list) + block_index_list = [] + block_index_list.append(i) + if len(block_index_list) > 0: + info_list.append(block_index_list) + return info_list + + def load_state_dict(self, state_dict, **kwargs): + current_state_dict = self.state_dict() + + for key in state_dict: + if key not in current_state_dict: + assert ".mobile_inverted_conv." in key + new_key = key.replace(".mobile_inverted_conv.", ".conv.") + else: + new_key = key + current_state_dict[new_key] = state_dict[key] + super(ProxylessNASNet, self).load_state_dict(current_state_dict) + + +class MobileNetV2(ProxylessNASNet): + def __init__( + self, + n_classes=1000, + width_mult=1.0, + bn_param=(0.1, 1e-3), + dropout_rate=0.2, + ks=None, + expand_ratio=None, + depth_param=None, + stage_width_list=None, + ): + + ks = 3 if ks is None else ks + expand_ratio = 6 if expand_ratio is None else expand_ratio + + input_channel = 32 + last_channel = 1280 + + input_channel = make_divisible(input_channel * width_mult) + last_channel = ( + make_divisible(last_channel * width_mult) + if width_mult > 1.0 + else last_channel + ) + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [expand_ratio, 24, 2, 2], + [expand_ratio, 32, 3, 2], + [expand_ratio, 64, 4, 2], + [expand_ratio, 96, 3, 1], + [expand_ratio, 160, 3, 2], + [expand_ratio, 320, 1, 1], + ] + + if depth_param is not None: + assert isinstance(depth_param, int) + for i in range(1, len(inverted_residual_setting) - 1): + inverted_residual_setting[i][2] = depth_param + + if stage_width_list is not None: + for i in range(len(inverted_residual_setting)): + inverted_residual_setting[i][1] = stage_width_list[i] + + ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1) + _pt = 0 + + # first conv layer + first_conv = ConvLayer( + 3, + input_channel, + kernel_size=3, + stride=2, + use_bn=True, + act_func="relu6", + ops_order="weight_bn_act", + ) + # inverted residual blocks + blocks = [] + for t, c, n, s in inverted_residual_setting: + output_channel = make_divisible(c * width_mult) + for i in range(n): + if i == 0: + stride = s + else: + stride = 1 + if t == 1: + kernel_size = 3 + else: + kernel_size = ks[_pt] + _pt += 1 + mobile_inverted_conv = MBConvLayer( + in_channels=input_channel, + out_channels=output_channel, + kernel_size=kernel_size, + stride=stride, + expand_ratio=t, + ) + if stride == 1: + if input_channel == output_channel: + shortcut = IdentityLayer(input_channel, input_channel) + else: + shortcut = None + else: + shortcut = None + blocks.append(ResidualBlock(mobile_inverted_conv, shortcut)) + input_channel = output_channel + # 1x1_conv before global average pooling + feature_mix_layer = ConvLayer( + input_channel, + last_channel, + kernel_size=1, + use_bn=True, + act_func="relu6", + ops_order="weight_bn_act", + ) + + classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate) + + super(MobileNetV2, self).__init__( + first_conv, blocks, feature_mix_layer, classifier + ) + + # set bn param + self.set_bn_param(*bn_param) diff --git a/xnas/spaces/OFA/ProxylessNet/ofa_cnn.py b/xnas/spaces/OFA/ProxylessNet/ofa_cnn.py new file mode 100644 index 0000000..796d6da --- /dev/null +++ b/xnas/spaces/OFA/ProxylessNet/ofa_cnn.py @@ -0,0 +1,390 @@ +import random +from copy import deepcopy + +from xnas.spaces.OFA.dynamic_ops import DynamicMBConvLayer +from xnas.spaces.OFA.utils import val2list, make_divisible +from xnas.spaces.OFA.ops import ( + ConvLayer, + IdentityLayer, + LinearLayer, + MBConvLayer, + ResidualBlock, +) +from xnas.spaces.OFA.ProxylessNet.cnn import ProxylessNASNet + + +__all__ = ["_OFAProxylessNASNet", "OFAProxylessNASNet"] + + + +class OFAProxylessNASNet(ProxylessNASNet): + def __init__( + self, + n_classes=1000, + bn_param=(0.1, 1e-3), + dropout_rate=0.1, + base_stage_width=None, + width_mult=1.0, + ks_list=3, + expand_ratio_list=6, + depth_list=4, + ): + + self.width_mult = width_mult + self.ks_list = val2list(ks_list, 1) + self.expand_ratio_list = val2list(expand_ratio_list, 1) + self.depth_list = val2list(depth_list, 1) + + self.ks_list.sort() + self.expand_ratio_list.sort() + self.depth_list.sort() + + if base_stage_width == "google": + # MobileNetV2 Stage Width + base_stage_width = [32, 16, 24, 32, 64, 96, 160, 320, 1280] + else: + # ProxylessNAS Stage Width + base_stage_width = [32, 16, 24, 40, 80, 96, 192, 320, 1280] + + input_channel = make_divisible(base_stage_width[0] * self.width_mult) + first_block_width = make_divisible(base_stage_width[1] * self.width_mult) + last_channel = make_divisible(base_stage_width[-1] * self.width_mult) + + # first conv layer + first_conv = ConvLayer( + 3, + input_channel, + kernel_size=3, + stride=2, + use_bn=True, + act_func="relu6", + ops_order="weight_bn_act", + ) + # first block + first_block_conv = MBConvLayer( + in_channels=input_channel, + out_channels=first_block_width, + kernel_size=3, + stride=1, + expand_ratio=1, + act_func="relu6", + ) + first_block = ResidualBlock(first_block_conv, None) + + input_channel = first_block_width + # inverted residual blocks + self.block_group_info = [] + blocks = [first_block] + _block_index = 1 + + stride_stages = [2, 2, 2, 1, 2, 1] + n_block_list = [max(self.depth_list)] * 5 + [1] + + width_list = [] + for base_width in base_stage_width[2:-1]: + width = make_divisible(base_width * self.width_mult) + width_list.append(width) + + for width, n_block, s in zip(width_list, n_block_list, stride_stages): + self.block_group_info.append([_block_index + i for i in range(n_block)]) + _block_index += n_block + + output_channel = width + for i in range(n_block): + if i == 0: + stride = s + else: + stride = 1 + + mobile_inverted_conv = DynamicMBConvLayer( + in_channel_list=val2list(input_channel, 1), + out_channel_list=val2list(output_channel, 1), + kernel_size_list=ks_list, + expand_ratio_list=expand_ratio_list, + stride=stride, + act_func="relu6", + ) + + if stride == 1 and input_channel == output_channel: + shortcut = IdentityLayer(input_channel, input_channel) + else: + shortcut = None + + mb_inverted_block = ResidualBlock(mobile_inverted_conv, shortcut) + + blocks.append(mb_inverted_block) + input_channel = output_channel + # 1x1_conv before global average pooling + feature_mix_layer = ConvLayer( + input_channel, + last_channel, + kernel_size=1, + use_bn=True, + act_func="relu6", + ) + classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate) + + super(OFAProxylessNASNet, self).__init__( + first_conv, blocks, feature_mix_layer, classifier + ) + + # set bn param + self.set_bn_param(momentum=bn_param[0], eps=bn_param[1]) + + # runtime_depth + self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info] + + """ WSConv_Network required methods """ + + @staticmethod + def name(): + return "OFAProxylessNASNet" + + def forward(self, x): + # first conv + x = self.first_conv(x) + # first block + x = self.blocks[0](x) + + # blocks + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + for idx in active_idx: + x = self.blocks[idx](x) + + # feature_mix_layer + x = self.feature_mix_layer(x) + x = x.mean(3).mean(2) + + x = self.classifier(x) + return x + + @property + def module_str(self): + _str = self.first_conv.module_str + "\n" + _str += self.blocks[0].module_str + "\n" + + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + for idx in active_idx: + _str += self.blocks[idx].module_str + "\n" + _str += self.feature_mix_layer.module_str + "\n" + _str += self.classifier.module_str + "\n" + return _str + + @property + def config(self): + return { + "name": OFAProxylessNASNet.__name__, + "bn": self.get_bn_param(), + "first_conv": self.first_conv.config, + "blocks": [block.config for block in self.blocks], + "feature_mix_layer": None + if self.feature_mix_layer is None + else self.feature_mix_layer.config, + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + raise ValueError("do not support this function") + + @property + def grouped_block_index(self): + return self.block_group_info + + def load_state_dict(self, state_dict, **kwargs): + model_dict = self.state_dict() + for key in state_dict: + if ".mobile_inverted_conv." in key: + new_key = key.replace(".mobile_inverted_conv.", ".conv.") + else: + new_key = key + if new_key in model_dict: + pass + elif ".bn.bn." in new_key: + new_key = new_key.replace(".bn.bn.", ".bn.") + elif ".conv.conv.weight" in new_key: + new_key = new_key.replace(".conv.conv.weight", ".conv.weight") + elif ".linear.linear." in new_key: + new_key = new_key.replace(".linear.linear.", ".linear.") + ############################################################################## + elif ".linear." in new_key: + new_key = new_key.replace(".linear.", ".linear.linear.") + elif "bn." in new_key: + new_key = new_key.replace("bn.", "bn.bn.") + elif "conv.weight" in new_key: + new_key = new_key.replace("conv.weight", "conv.conv.weight") + else: + raise ValueError(new_key) + assert new_key in model_dict, "%s" % new_key + model_dict[new_key] = state_dict[key] + super(OFAProxylessNASNet, self).load_state_dict(model_dict) + + """ set, sample and get active sub-networks """ + + def set_max_net(self): + self.set_active_subnet( + ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list) + ) + + def set_active_subnet(self, ks=None, e=None, d=None, **kwargs): + ks = val2list(ks, len(self.blocks) - 1) + expand_ratio = val2list(e, len(self.blocks) - 1) + depth = val2list(d, len(self.block_group_info)) + + for block, k, e in zip(self.blocks[1:], ks, expand_ratio): + if k is not None: + block.conv.active_kernel_size = k + if e is not None: + block.conv.active_expand_ratio = e + + for i, d in enumerate(depth): + if d is not None: + self.runtime_depth[i] = min(len(self.block_group_info[i]), d) + + def set_constraint(self, include_list, constraint_type="depth"): + if constraint_type == "depth": + self.__dict__["_depth_include_list"] = include_list.copy() + elif constraint_type == "expand_ratio": + self.__dict__["_expand_include_list"] = include_list.copy() + elif constraint_type == "kernel_size": + self.__dict__["_ks_include_list"] = include_list.copy() + else: + raise NotImplementedError + + def clear_constraint(self): + self.__dict__["_depth_include_list"] = None + self.__dict__["_expand_include_list"] = None + self.__dict__["_ks_include_list"] = None + + def sample_active_subnet(self): + ks_candidates = ( + self.ks_list + if self.__dict__.get("_ks_include_list", None) is None + else self.__dict__["_ks_include_list"] + ) + expand_candidates = ( + self.expand_ratio_list + if self.__dict__.get("_expand_include_list", None) is None + else self.__dict__["_expand_include_list"] + ) + depth_candidates = ( + self.depth_list + if self.__dict__.get("_depth_include_list", None) is None + else self.__dict__["_depth_include_list"] + ) + + # sample kernel size + ks_setting = [] + if not isinstance(ks_candidates[0], list): + ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)] + for k_set in ks_candidates: + k = random.choice(k_set) + ks_setting.append(k) + + # sample expand ratio + expand_setting = [] + if not isinstance(expand_candidates[0], list): + expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)] + for e_set in expand_candidates: + e = random.choice(e_set) + expand_setting.append(e) + + # sample depth + depth_setting = [] + if not isinstance(depth_candidates[0], list): + depth_candidates = [ + depth_candidates for _ in range(len(self.block_group_info)) + ] + for d_set in depth_candidates: + d = random.choice(d_set) + depth_setting.append(d) + + depth_setting[-1] = 1 + self.set_active_subnet(ks_setting, expand_setting, depth_setting) + + return { + "ks": ks_setting, + "e": expand_setting, + "d": depth_setting, + } + + def get_active_subnet(self, preserve_weight=True): + first_conv = deepcopy(self.first_conv) + blocks = [deepcopy(self.blocks[0])] + feature_mix_layer = deepcopy(self.feature_mix_layer) + classifier = deepcopy(self.classifier) + + input_channel = blocks[0].conv.out_channels + # blocks + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + stage_blocks = [] + for idx in active_idx: + stage_blocks.append( + ResidualBlock( + self.blocks[idx].conv.get_active_subnet( + input_channel, preserve_weight + ), + deepcopy(self.blocks[idx].shortcut), + ) + ) + input_channel = stage_blocks[-1].conv.out_channels + blocks += stage_blocks + + _subnet = ProxylessNASNet(first_conv, blocks, feature_mix_layer, classifier) + _subnet.set_bn_param(**self.get_bn_param()) + return _subnet + + def get_active_net_config(self): + first_conv_config = self.first_conv.config + first_block_config = self.blocks[0].config + feature_mix_layer_config = self.feature_mix_layer.config + classifier_config = self.classifier.config + + block_config_list = [first_block_config] + input_channel = first_block_config["conv"]["out_channels"] + for stage_id, block_idx in enumerate(self.block_group_info): + depth = self.runtime_depth[stage_id] + active_idx = block_idx[:depth] + stage_blocks = [] + for idx in active_idx: + stage_blocks.append( + { + "name": ResidualBlock.__name__, + "conv": self.blocks[idx].conv.get_active_subnet_config( + input_channel + ), + "shortcut": self.blocks[idx].shortcut.config + if self.blocks[idx].shortcut is not None + else None, + } + ) + try: + input_channel = self.blocks[idx].conv.active_out_channel + except Exception: + input_channel = self.blocks[idx].conv.out_channels + block_config_list += stage_blocks + + return { + "name": ProxylessNASNet.__name__, + "bn": self.get_bn_param(), + "first_conv": first_conv_config, + "blocks": block_config_list, + "feature_mix_layer": feature_mix_layer_config, + "classifier": classifier_config, + } + + """ Width Related Methods """ + + def re_organize_middle_weights(self, expand_ratio_stage=0): + for block in self.blocks[1:]: + block.conv.re_organize_middle_weights(expand_ratio_stage) + + +def _OFAProxylessNASNet(): + return OFAProxylessNASNet() diff --git a/xnas/spaces/OFA/ResNets/cnn.py b/xnas/spaces/OFA/ResNets/cnn.py new file mode 100644 index 0000000..b1b9ea8 --- /dev/null +++ b/xnas/spaces/OFA/ResNets/cnn.py @@ -0,0 +1,248 @@ +import torch.nn as nn + +from xnas.spaces.OFA.utils import make_divisible +from xnas.spaces.OFA.ops import ( + set_layer_from_config, + ConvLayer, + IdentityLayer, + LinearLayer, + ResidualBlock, + ResNetBottleneckBlock, + GlobalAvgPool2d +) +from xnas.spaces.OFA.MobileNetV3.cnn import WSConv_Network + + +__all__ = ["ResNet", "ResNet50", "ResNet50D"] + + +class ResNet(WSConv_Network): + + BASE_DEPTH_LIST = [2, 2, 4, 2] + STAGE_WIDTH_LIST = [256, 512, 1024, 2048] + + def __init__(self, input_stem, blocks, classifier): + super(ResNet, self).__init__() + + self.input_stem = nn.ModuleList(input_stem) + self.max_pooling = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False + ) + self.blocks = nn.ModuleList(blocks) + self.global_avg_pool = GlobalAvgPool2d(keep_dim=False) + self.classifier = classifier + + def forward(self, x): + for layer in self.input_stem: + x = layer(x) + x = self.max_pooling(x) + for block in self.blocks: + x = block(x) + x = self.global_avg_pool(x) + x = self.classifier(x) + return x + + @property + def module_str(self): + _str = "" + for layer in self.input_stem: + _str += layer.module_str + "\n" + _str += "max_pooling(ks=3, stride=2)\n" + for block in self.blocks: + _str += block.module_str + "\n" + _str += self.global_avg_pool.__repr__() + "\n" + _str += self.classifier.module_str + return _str + + @property + def config(self): + return { + "name": ResNet.__name__, + "bn": self.get_bn_param(), + "input_stem": [layer.config for layer in self.input_stem], + "blocks": [block.config for block in self.blocks], + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + classifier = set_layer_from_config(config["classifier"]) + + input_stem = [] + for layer_config in config["input_stem"]: + input_stem.append(set_layer_from_config(layer_config)) + blocks = [] + for block_config in config["blocks"]: + blocks.append(set_layer_from_config(block_config)) + + net = ResNet(input_stem, blocks, classifier) + if "bn" in config: + net.set_bn_param(**config["bn"]) + else: + net.set_bn_param(momentum=0.1, eps=1e-5) + + return net + + def zero_last_gamma(self): + for m in self.modules(): + if isinstance(m, ResNetBottleneckBlock) and isinstance( + m.downsample, IdentityLayer + ): + m.conv3.bn.weight.data.zero_() + + @property + def grouped_block_index(self): + info_list = [] + block_index_list = [] + for i, block in enumerate(self.blocks): + if ( + not isinstance(block.downsample, IdentityLayer) + and len(block_index_list) > 0 + ): + info_list.append(block_index_list) + block_index_list = [] + block_index_list.append(i) + if len(block_index_list) > 0: + info_list.append(block_index_list) + return info_list + + def load_state_dict(self, state_dict, **kwargs): + super(ResNet, self).load_state_dict(state_dict) + + +class ResNet50(ResNet): + def __init__( + self, + n_classes=1000, + width_mult=1.0, + bn_param=(0.1, 1e-5), + dropout_rate=0, + expand_ratio=None, + depth_param=None, + ): + + expand_ratio = 0.25 if expand_ratio is None else expand_ratio + + input_channel = make_divisible(64 * width_mult) + stage_width_list = ResNet.STAGE_WIDTH_LIST.copy() + for i, width in enumerate(stage_width_list): + stage_width_list[i] = make_divisible(width * width_mult) + + depth_list = [3, 4, 6, 3] + if depth_param is not None: + for i, depth in enumerate(ResNet.BASE_DEPTH_LIST): + depth_list[i] = depth + depth_param + + stride_list = [1, 2, 2, 2] + + # build input stem + input_stem = [ + ConvLayer( + 3, + input_channel, + kernel_size=7, + stride=2, + use_bn=True, + act_func="relu", + ops_order="weight_bn_act", + ) + ] + + # blocks + blocks = [] + for d, width, s in zip(depth_list, stage_width_list, stride_list): + for i in range(d): + stride = s if i == 0 else 1 + bottleneck_block = ResNetBottleneckBlock( + input_channel, + width, + kernel_size=3, + stride=stride, + expand_ratio=expand_ratio, + act_func="relu", + downsample_mode="conv", + ) + blocks.append(bottleneck_block) + input_channel = width + # classifier + classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate) + + super(ResNet50, self).__init__(input_stem, blocks, classifier) + + # set bn param + self.set_bn_param(*bn_param) + + +class ResNet50D(ResNet): + def __init__( + self, + n_classes=1000, + width_mult=1.0, + bn_param=(0.1, 1e-5), + dropout_rate=0, + expand_ratio=None, + depth_param=None, + ): + + expand_ratio = 0.25 if expand_ratio is None else expand_ratio + + input_channel = make_divisible(64 * width_mult) + mid_input_channel = make_divisible(input_channel // 2) + stage_width_list = ResNet.STAGE_WIDTH_LIST.copy() + for i, width in enumerate(stage_width_list): + stage_width_list[i] = make_divisible(width * width_mult) + + depth_list = [3, 4, 6, 3] + if depth_param is not None: + for i, depth in enumerate(ResNet.BASE_DEPTH_LIST): + depth_list[i] = depth + depth_param + + stride_list = [1, 2, 2, 2] + + # build input stem + input_stem = [ + ConvLayer(3, mid_input_channel, 3, stride=2, use_bn=True, act_func="relu"), + ResidualBlock( + ConvLayer( + mid_input_channel, + mid_input_channel, + 3, + stride=1, + use_bn=True, + act_func="relu", + ), + IdentityLayer(mid_input_channel, mid_input_channel), + ), + ConvLayer( + mid_input_channel, + input_channel, + 3, + stride=1, + use_bn=True, + act_func="relu", + ), + ] + + # blocks + blocks = [] + for d, width, s in zip(depth_list, stage_width_list, stride_list): + for i in range(d): + stride = s if i == 0 else 1 + bottleneck_block = ResNetBottleneckBlock( + input_channel, + width, + kernel_size=3, + stride=stride, + expand_ratio=expand_ratio, + act_func="relu", + downsample_mode="avgpool_conv", + ) + blocks.append(bottleneck_block) + input_channel = width + # classifier + classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate) + + super(ResNet50D, self).__init__(input_stem, blocks, classifier) + + # set bn param + self.set_bn_param(*bn_param) diff --git a/xnas/spaces/OFA/ResNets/ofa_cnn.py b/xnas/spaces/OFA/ResNets/ofa_cnn.py new file mode 100644 index 0000000..84949af --- /dev/null +++ b/xnas/spaces/OFA/ResNets/ofa_cnn.py @@ -0,0 +1,353 @@ +import random + +from xnas.spaces.OFA.utils import val2list, make_divisible +from xnas.spaces.OFA.dynamic_ops import ( + DynamicConvLayer, + DynamicLinearLayer, + DynamicResNetBottleneckBlock, +) +from xnas.spaces.OFA.ops import ( + IdentityLayer, + ResidualBlock, +) +from xnas.spaces.OFA.ResNets.cnn import ResNet + + +__all__ = ["_OFAResNet", "OFAResNet"] + + +class OFAResNet(ResNet): + def __init__( + self, + n_classes=1000, + bn_param=(0.1, 1e-5), + dropout_rate=0, + depth_list=2, + expand_ratio_list=0.25, + width_mult_list=1.0, + ): + + self.depth_list = val2list(depth_list) + self.expand_ratio_list = val2list(expand_ratio_list) + self.width_mult_list = val2list(width_mult_list) + # sort + self.depth_list.sort() + self.expand_ratio_list.sort() + self.width_mult_list.sort() + + input_channel = [ + make_divisible(64 * width_mult) + for width_mult in self.width_mult_list + ] + mid_input_channel = [ + make_divisible(channel // 2) + for channel in input_channel + ] + + stage_width_list = ResNet.STAGE_WIDTH_LIST.copy() + for i, width in enumerate(stage_width_list): + stage_width_list[i] = [ + make_divisible(width * width_mult) + for width_mult in self.width_mult_list + ] + + n_block_list = [ + base_depth + max(self.depth_list) for base_depth in ResNet.BASE_DEPTH_LIST + ] + stride_list = [1, 2, 2, 2] + + # build input stem + input_stem = [ + DynamicConvLayer( + val2list(3), + mid_input_channel, + 3, + stride=2, + use_bn=True, + act_func="relu", + ), + ResidualBlock( + DynamicConvLayer( + mid_input_channel, + mid_input_channel, + 3, + stride=1, + use_bn=True, + act_func="relu", + ), + IdentityLayer(mid_input_channel, mid_input_channel), + ), + DynamicConvLayer( + mid_input_channel, + input_channel, + 3, + stride=1, + use_bn=True, + act_func="relu", + ), + ] + + # blocks + blocks = [] + for d, width, s in zip(n_block_list, stage_width_list, stride_list): + for i in range(d): + stride = s if i == 0 else 1 + bottleneck_block = DynamicResNetBottleneckBlock( + input_channel, + width, + expand_ratio_list=self.expand_ratio_list, + kernel_size=3, + stride=stride, + act_func="relu", + downsample_mode="avgpool_conv", + ) + blocks.append(bottleneck_block) + input_channel = width + # classifier + classifier = DynamicLinearLayer( + input_channel, n_classes, dropout_rate=dropout_rate + ) + + super(OFAResNet, self).__init__(input_stem, blocks, classifier) + + # set bn param + self.set_bn_param(*bn_param) + + # runtime_depth + self.input_stem_skipping = 0 + self.runtime_depth = [0] * len(n_block_list) + + @property + def ks_list(self): + return [3] + + @staticmethod + def name(): + return "OFAResNet" + + def forward(self, x): + for layer in self.input_stem: + if ( + self.input_stem_skipping > 0 + and isinstance(layer, ResidualBlock) + and isinstance(layer.shortcut, IdentityLayer) + ): + pass + else: + x = layer(x) + x = self.max_pooling(x) + for stage_id, block_idx in enumerate(self.grouped_block_index): + depth_param = self.runtime_depth[stage_id] + active_idx = block_idx[: len(block_idx) - depth_param] + for idx in active_idx: + x = self.blocks[idx](x) + x = self.global_avg_pool(x) + x = self.classifier(x) + return x + + @property + def module_str(self): + _str = "" + for layer in self.input_stem: + if ( + self.input_stem_skipping > 0 + and isinstance(layer, ResidualBlock) + and isinstance(layer.shortcut, IdentityLayer) + ): + pass + else: + _str += layer.module_str + "\n" + _str += "max_pooling(ks=3, stride=2)\n" + for stage_id, block_idx in enumerate(self.grouped_block_index): + depth_param = self.runtime_depth[stage_id] + active_idx = block_idx[: len(block_idx) - depth_param] + for idx in active_idx: + _str += self.blocks[idx].module_str + "\n" + _str += self.global_avg_pool.__repr__() + "\n" + _str += self.classifier.module_str + return _str + + @property + def config(self): + return { + "name": OFAResNet.__name__, + "bn": self.get_bn_param(), + "input_stem": [layer.config for layer in self.input_stem], + "blocks": [block.config for block in self.blocks], + "classifier": self.classifier.config, + } + + @staticmethod + def build_from_config(config): + raise ValueError("do not support this function") + + def load_state_dict(self, state_dict, **kwargs): + model_dict = self.state_dict() + for key in state_dict: + new_key = key + if new_key in model_dict: + pass + elif ".linear." in new_key: + new_key = new_key.replace(".linear.", ".linear.linear.") + elif "bn." in new_key: + new_key = new_key.replace("bn.", "bn.bn.") + elif "conv.weight" in new_key: + new_key = new_key.replace("conv.weight", "conv.conv.weight") + else: + raise ValueError(new_key) + assert new_key in model_dict, "%s" % new_key + model_dict[new_key] = state_dict[key] + super(OFAResNet, self).load_state_dict(model_dict) + + """ set, sample and get active sub-networks """ + + def set_max_net(self): + self.set_active_subnet( + d=max(self.depth_list), + e=max(self.expand_ratio_list), + w=len(self.width_mult_list) - 1, + ) + + def set_active_subnet(self, d=None, e=None, w=None, **kwargs): + depth = val2list(d, len(ResNet.BASE_DEPTH_LIST) + 1) + expand_ratio = val2list(e, len(self.blocks)) + width_mult = val2list(w, len(ResNet.BASE_DEPTH_LIST) + 2) + + for block, e in zip(self.blocks, expand_ratio): + if e is not None: + block.active_expand_ratio = e + + if width_mult[0] is not None: + self.input_stem[1].conv.active_out_channel = self.input_stem[ + 0 + ].active_out_channel = self.input_stem[0].out_channel_list[width_mult[0]] + if width_mult[1] is not None: + self.input_stem[2].active_out_channel = self.input_stem[2].out_channel_list[ + width_mult[1] + ] + + if depth[0] is not None: + self.input_stem_skipping = depth[0] != max(self.depth_list) + for stage_id, (block_idx, d, w) in enumerate( + zip(self.grouped_block_index, depth[1:], width_mult[2:]) + ): + if d is not None: + self.runtime_depth[stage_id] = max(self.depth_list) - d + if w is not None: + for idx in block_idx: + self.blocks[idx].active_out_channel = self.blocks[ + idx + ].out_channel_list[w] + + def sample_active_subnet(self): + # sample expand ratio + expand_setting = [] + for block in self.blocks: + expand_setting.append(random.choice(block.expand_ratio_list)) + + # sample depth + depth_setting = [random.choice([max(self.depth_list), min(self.depth_list)])] + for stage_id in range(len(ResNet.BASE_DEPTH_LIST)): + depth_setting.append(random.choice(self.depth_list)) + + # sample width_mult + width_mult_setting = [ + random.choice(list(range(len(self.input_stem[0].out_channel_list)))), + random.choice(list(range(len(self.input_stem[2].out_channel_list)))), + ] + for stage_id, block_idx in enumerate(self.grouped_block_index): + stage_first_block = self.blocks[block_idx[0]] + width_mult_setting.append( + random.choice(list(range(len(stage_first_block.out_channel_list)))) + ) + + arch_config = {"d": depth_setting, "e": expand_setting, "w": width_mult_setting} + self.set_active_subnet(**arch_config) + return arch_config + + def get_active_subnet(self, preserve_weight=True): + input_stem = [self.input_stem[0].get_active_subnet(3, preserve_weight)] + if self.input_stem_skipping <= 0: + input_stem.append( + ResidualBlock( + self.input_stem[1].conv.get_active_subnet( + self.input_stem[0].active_out_channel, preserve_weight + ), + IdentityLayer( + self.input_stem[0].active_out_channel, + self.input_stem[0].active_out_channel, + ), + ) + ) + input_stem.append( + self.input_stem[2].get_active_subnet( + self.input_stem[0].active_out_channel, preserve_weight + ) + ) + input_channel = self.input_stem[2].active_out_channel + + blocks = [] + for stage_id, block_idx in enumerate(self.grouped_block_index): + depth_param = self.runtime_depth[stage_id] + active_idx = block_idx[: len(block_idx) - depth_param] + for idx in active_idx: + blocks.append( + self.blocks[idx].get_active_subnet(input_channel, preserve_weight) + ) + input_channel = self.blocks[idx].active_out_channel + classifier = self.classifier.get_active_subnet(input_channel, preserve_weight) + subnet = ResNet(input_stem, blocks, classifier) + + subnet.set_bn_param(**self.get_bn_param()) + return subnet + + def get_active_net_config(self): + input_stem_config = [self.input_stem[0].get_active_subnet_config(3)] + if self.input_stem_skipping <= 0: + input_stem_config.append( + { + "name": ResidualBlock.__name__, + "conv": self.input_stem[1].conv.get_active_subnet_config( + self.input_stem[0].active_out_channel + ), + "shortcut": IdentityLayer( + self.input_stem[0].active_out_channel, + self.input_stem[0].active_out_channel, + ), + } + ) + input_stem_config.append( + self.input_stem[2].get_active_subnet_config( + self.input_stem[0].active_out_channel + ) + ) + input_channel = self.input_stem[2].active_out_channel + + blocks_config = [] + for stage_id, block_idx in enumerate(self.grouped_block_index): + depth_param = self.runtime_depth[stage_id] + active_idx = block_idx[: len(block_idx) - depth_param] + for idx in active_idx: + blocks_config.append( + self.blocks[idx].get_active_subnet_config(input_channel) + ) + input_channel = self.blocks[idx].active_out_channel + classifier_config = self.classifier.get_active_subnet_config(input_channel) + return { + "name": ResNet.__name__, + "bn": self.get_bn_param(), + "input_stem": input_stem_config, + "blocks": blocks_config, + "classifier": classifier_config, + } + + """ Width Related Methods """ + + def re_organize_middle_weights(self, expand_ratio_stage=0): + for block in self.blocks: + block.re_organize_middle_weights(expand_ratio_stage) + + +def _OFAResNet(): + return OFAResNet() + diff --git a/xnas/spaces/OFA/dynamic_ops.py b/xnas/spaces/OFA/dynamic_ops.py new file mode 100644 index 0000000..f150dc2 --- /dev/null +++ b/xnas/spaces/OFA/dynamic_ops.py @@ -0,0 +1,1182 @@ +from copy import deepcopy +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from xnas.spaces.OFA.utils import ( + get_same_padding, + val2list, + make_divisible, +) +from xnas.spaces.OFA.ops import ( + build_activation, + set_layer_from_config, + SEModule, + WeightStandardConv2d, + MBConvLayer, + ConvLayer, + IdentityLayer, + ResNetBottleneckBlock, + LinearLayer, +) + + +class DynamicSeparableConv2d(nn.Module): + KERNEL_TRANSFORM_MODE = 1 # None or 1 + + def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1): + super(DynamicSeparableConv2d, self).__init__() + + self.max_in_channels = max_in_channels + self.kernel_size_list = kernel_size_list + self.stride = stride + self.dilation = dilation + + self.conv = nn.Conv2d( + self.max_in_channels, + self.max_in_channels, + max(self.kernel_size_list), + self.stride, + groups=self.max_in_channels, + bias=False, + ) + + self._ks_set = list(set(self.kernel_size_list)) + self._ks_set.sort() # e.g., [3, 5, 7] + if self.KERNEL_TRANSFORM_MODE is not None: + # register scaling parameters + # 7to5_matrix, 5to3_matrix + scale_params = {} + for i in range(len(self._ks_set) - 1): + ks_small = self._ks_set[i] + ks_larger = self._ks_set[i + 1] + param_name = "%dto%d" % (ks_larger, ks_small) + # noinspection PyArgumentList + scale_params["%s_matrix" % param_name] = Parameter( + torch.eye(ks_small ** 2) + ) + for name, param in scale_params.items(): + self.register_parameter(name, param) + + self.active_kernel_size = max(self.kernel_size_list) + + def get_active_filter(self, in_channel, kernel_size): + out_channel = in_channel + max_kernel_size = max(self.kernel_size_list) + + start, end = sub_filter_start_end(max_kernel_size, kernel_size) + filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end] + if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size: + start_filter = self.conv.weight[ + :out_channel, :in_channel, :, : + ] # start with max kernel + for i in range(len(self._ks_set) - 1, 0, -1): + src_ks = self._ks_set[i] + if src_ks <= kernel_size: + break + target_ks = self._ks_set[i - 1] + start, end = sub_filter_start_end(src_ks, target_ks) + _input_filter = start_filter[:, :, start:end, start:end] + _input_filter = _input_filter.contiguous() + _input_filter = _input_filter.view( + _input_filter.size(0), _input_filter.size(1), -1 + ) + _input_filter = _input_filter.view(-1, _input_filter.size(2)) + _input_filter = F.linear( + _input_filter, + self.__getattr__("%dto%d_matrix" % (src_ks, target_ks)), + ) + _input_filter = _input_filter.view( + filters.size(0), filters.size(1), target_ks ** 2 + ) + _input_filter = _input_filter.view( + filters.size(0), filters.size(1), target_ks, target_ks + ) + start_filter = _input_filter + filters = start_filter + return filters + + def forward(self, x, kernel_size=None): + if kernel_size is None: + kernel_size = self.active_kernel_size + in_channel = x.size(1) + + filters = self.get_active_filter(in_channel, kernel_size).contiguous() + + padding = get_same_padding(kernel_size) + filters = ( + self.conv.weight_standardization(filters) + if isinstance(self.conv, WeightStandardConv2d) + else filters + ) + y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, in_channel) + return y + + +class DynamicConv2d(nn.Module): + def __init__( + self, max_in_channels, max_out_channels, kernel_size=1, stride=1, dilation=1 + ): + super(DynamicConv2d, self).__init__() + + self.max_in_channels = max_in_channels + self.max_out_channels = max_out_channels + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + self.conv = nn.Conv2d( + self.max_in_channels, + self.max_out_channels, + self.kernel_size, + stride=self.stride, + bias=False, + ) + + self.active_out_channel = self.max_out_channels + + def get_active_filter(self, out_channel, in_channel): + return self.conv.weight[:out_channel, :in_channel, :, :] + + def forward(self, x, out_channel=None): + if out_channel is None: + out_channel = self.active_out_channel + in_channel = x.size(1) + filters = self.get_active_filter(out_channel, in_channel).contiguous() + + padding = get_same_padding(self.kernel_size) + filters = ( + self.conv.weight_standardization(filters) + if isinstance(self.conv, WeightStandardConv2d) + else filters + ) + y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, 1) + return y + + +class DynamicGroupConv2d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size_list, + groups_list, + stride=1, + dilation=1, + ): + super(DynamicGroupConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size_list = kernel_size_list + self.groups_list = groups_list + self.stride = stride + self.dilation = dilation + + self.conv = nn.Conv2d( + self.in_channels, + self.out_channels, + max(self.kernel_size_list), + self.stride, + groups=min(self.groups_list), + bias=False, + ) + + self.active_kernel_size = max(self.kernel_size_list) + self.active_groups = min(self.groups_list) + + def get_active_filter(self, kernel_size, groups): + start, end = sub_filter_start_end(max(self.kernel_size_list), kernel_size) + filters = self.conv.weight[:, :, start:end, start:end] + + sub_filters = torch.chunk(filters, groups, dim=0) + sub_in_channels = self.in_channels // groups + sub_ratio = filters.size(1) // sub_in_channels + + filter_crops = [] + for i, sub_filter in enumerate(sub_filters): + part_id = i % sub_ratio + start = part_id * sub_in_channels + filter_crops.append(sub_filter[:, start : start + sub_in_channels, :, :]) + filters = torch.cat(filter_crops, dim=0) + return filters + + def forward(self, x, kernel_size=None, groups=None): + if kernel_size is None: + kernel_size = self.active_kernel_size + if groups is None: + groups = self.active_groups + + filters = self.get_active_filter(kernel_size, groups).contiguous() + padding = get_same_padding(kernel_size) + filters = ( + self.conv.weight_standardization(filters) + if isinstance(self.conv, WeightStandardConv2d) + else filters + ) + y = F.conv2d( + x, + filters, + None, + self.stride, + padding, + self.dilation, + groups, + ) + return y + + +class DynamicBatchNorm2d(nn.Module): + SET_RUNNING_STATISTICS = False + + def __init__(self, max_feature_dim): + super(DynamicBatchNorm2d, self).__init__() + + self.max_feature_dim = max_feature_dim + self.bn = nn.BatchNorm2d(self.max_feature_dim) + + @staticmethod + def bn_forward(x, bn: nn.BatchNorm2d, feature_dim): + if bn.num_features == feature_dim or DynamicBatchNorm2d.SET_RUNNING_STATISTICS: + return bn(x) + else: + exponential_average_factor = 0.0 + + if bn.training and bn.track_running_stats: + if bn.num_batches_tracked is not None: + bn.num_batches_tracked += 1 + if bn.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(bn.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = bn.momentum + return F.batch_norm( + x, + bn.running_mean[:feature_dim], + bn.running_var[:feature_dim], + bn.weight[:feature_dim], + bn.bias[:feature_dim], + bn.training or not bn.track_running_stats, + exponential_average_factor, + bn.eps, + ) + + def forward(self, x): + feature_dim = x.size(1) + y = self.bn_forward(x, self.bn, feature_dim) + return y + + +class DynamicGroupNorm(nn.GroupNorm): + def __init__( + self, num_groups, num_channels, eps=1e-5, affine=True, channel_per_group=None + ): + super(DynamicGroupNorm, self).__init__(num_groups, num_channels, eps, affine) + self.channel_per_group = channel_per_group + + def forward(self, x): + n_channels = x.size(1) + n_groups = n_channels // self.channel_per_group + return F.group_norm( + x, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps + ) + + @property + def bn(self): + return self + + +class DynamicSE(SEModule): + def __init__(self, max_channel): + super(DynamicSE, self).__init__(max_channel) + + def get_active_reduce_weight(self, num_mid, in_channel, groups=None): + if groups is None or groups == 1: + return self.fc.reduce.weight[:num_mid, :in_channel, :, :] + else: + assert in_channel % groups == 0 + sub_in_channels = in_channel // groups + sub_filters = torch.chunk( + self.fc.reduce.weight[:num_mid, :, :, :], groups, dim=1 + ) + return torch.cat( + [sub_filter[:, :sub_in_channels, :, :] for sub_filter in sub_filters], + dim=1, + ) + + def get_active_reduce_bias(self, num_mid): + return ( + self.fc.reduce.bias[:num_mid] if self.fc.reduce.bias is not None else None + ) + + def get_active_expand_weight(self, num_mid, in_channel, groups=None): + if groups is None or groups == 1: + return self.fc.expand.weight[:in_channel, :num_mid, :, :] + else: + assert in_channel % groups == 0 + sub_in_channels = in_channel // groups + sub_filters = torch.chunk( + self.fc.expand.weight[:, :num_mid, :, :], groups, dim=0 + ) + return torch.cat( + [sub_filter[:sub_in_channels, :, :, :] for sub_filter in sub_filters], + dim=0, + ) + + def get_active_expand_bias(self, in_channel, groups=None): + if groups is None or groups == 1: + return ( + self.fc.expand.bias[:in_channel] + if self.fc.expand.bias is not None + else None + ) + else: + assert in_channel % groups == 0 + sub_in_channels = in_channel // groups + sub_bias_list = torch.chunk(self.fc.expand.bias, groups, dim=0) + return torch.cat( + [sub_bias[:sub_in_channels] for sub_bias in sub_bias_list], dim=0 + ) + + def forward(self, x, groups=None): + in_channel = x.size(1) + num_mid = make_divisible(in_channel // self.reduction) + + y = x.mean(3, keepdim=True).mean(2, keepdim=True) + # reduce + reduce_filter = self.get_active_reduce_weight( + num_mid, in_channel, groups=groups + ).contiguous() + reduce_bias = self.get_active_reduce_bias(num_mid) + y = F.conv2d(y, reduce_filter, reduce_bias, 1, 0, 1, 1) + # relu + y = self.fc.relu(y) + # expand + expand_filter = self.get_active_expand_weight( + num_mid, in_channel, groups=groups + ).contiguous() + expand_bias = self.get_active_expand_bias(in_channel, groups=groups) + y = F.conv2d(y, expand_filter, expand_bias, 1, 0, 1, 1) + # hard sigmoid + y = self.fc.h_sigmoid(y) + + return x * y + + +class DynamicLinear(nn.Module): + def __init__(self, max_in_features, max_out_features, bias=True): + super(DynamicLinear, self).__init__() + + self.max_in_features = max_in_features + self.max_out_features = max_out_features + self.bias = bias + + self.linear = nn.Linear(self.max_in_features, self.max_out_features, self.bias) + + self.active_out_features = self.max_out_features + + def get_active_weight(self, out_features, in_features): + return self.linear.weight[:out_features, :in_features] + + def get_active_bias(self, out_features): + return self.linear.bias[:out_features] if self.bias else None + + def forward(self, x, out_features=None): + if out_features is None: + out_features = self.active_out_features + + in_features = x.size(1) + weight = self.get_active_weight(out_features, in_features).contiguous() + bias = self.get_active_bias(out_features) + y = F.linear(x, weight, bias) + return y + + +class DynamicLinearLayer(nn.Module): + def __init__(self, in_features_list, out_features, bias=True, dropout_rate=0): + super(DynamicLinearLayer, self).__init__() + + self.in_features_list = in_features_list + self.out_features = out_features + self.bias = bias + self.dropout_rate = dropout_rate + + if self.dropout_rate > 0: + self.dropout = nn.Dropout(self.dropout_rate, inplace=True) + else: + self.dropout = None + self.linear = DynamicLinear( + max_in_features=max(self.in_features_list), + max_out_features=self.out_features, + bias=self.bias, + ) + + def forward(self, x): + if self.dropout is not None: + x = self.dropout(x) + return self.linear(x) + + @property + def module_str(self): + return "DyLinear(%d, %d)" % (max(self.in_features_list), self.out_features) + + @property + def config(self): + return { + "name": DynamicLinear.__name__, + "in_features_list": self.in_features_list, + "out_features": self.out_features, + "bias": self.bias, + "dropout_rate": self.dropout_rate, + } + + @staticmethod + def build_from_config(config): + return DynamicLinearLayer(**config) + + def get_active_subnet(self, in_features, preserve_weight=True): + sub_layer = LinearLayer( + in_features, self.out_features, self.bias, dropout_rate=self.dropout_rate + ) + sub_layer = sub_layer.to(self.parameters().__next__().device) + if not preserve_weight: + return sub_layer + + sub_layer.linear.weight.data.copy_( + self.linear.get_active_weight(self.out_features, in_features).data + ) + if self.bias: + sub_layer.linear.bias.data.copy_( + self.linear.get_active_bias(self.out_features).data + ) + return sub_layer + + def get_active_subnet_config(self, in_features): + return { + "name": LinearLayer.__name__, + "in_features": in_features, + "out_features": self.out_features, + "bias": self.bias, + "dropout_rate": self.dropout_rate, + } + + +class DynamicMBConvLayer(nn.Module): + def __init__( + self, + in_channel_list, + out_channel_list, + kernel_size_list=3, + expand_ratio_list=6, + stride=1, + act_func="relu6", + use_se=False, + ): + super(DynamicMBConvLayer, self).__init__() + + self.in_channel_list = in_channel_list + self.out_channel_list = out_channel_list + + self.kernel_size_list = val2list(kernel_size_list) + self.expand_ratio_list = val2list(expand_ratio_list) + + self.stride = stride + self.act_func = act_func + self.use_se = use_se + + # build modules + max_middle_channel = make_divisible( + round(max(self.in_channel_list) * max(self.expand_ratio_list))) + if max(self.expand_ratio_list) == 1: + self.inverted_bottleneck = None + else: + self.inverted_bottleneck = nn.Sequential( + OrderedDict( + [ + ( + "conv", + DynamicConv2d( + max(self.in_channel_list), max_middle_channel + ), + ), + ("bn", DynamicBatchNorm2d(max_middle_channel)), + ("act", build_activation(self.act_func)), + ] + ) + ) + + self.depth_conv = nn.Sequential( + OrderedDict( + [ + ( + "conv", + DynamicSeparableConv2d( + max_middle_channel, self.kernel_size_list, self.stride + ), + ), + ("bn", DynamicBatchNorm2d(max_middle_channel)), + ("act", build_activation(self.act_func)), + ] + ) + ) + if self.use_se: + self.depth_conv.add_module("se", DynamicSE(max_middle_channel)) + + self.point_linear = nn.Sequential( + OrderedDict( + [ + ( + "conv", + DynamicConv2d(max_middle_channel, max(self.out_channel_list)), + ), + ("bn", DynamicBatchNorm2d(max(self.out_channel_list))), + ] + ) + ) + + self.active_kernel_size = max(self.kernel_size_list) + self.active_expand_ratio = max(self.expand_ratio_list) + self.active_out_channel = max(self.out_channel_list) + + def forward(self, x): + in_channel = x.size(1) + + if self.inverted_bottleneck is not None: + self.inverted_bottleneck.conv.active_out_channel = make_divisible( + round(in_channel * self.active_expand_ratio)) + + self.depth_conv.conv.active_kernel_size = self.active_kernel_size + self.point_linear.conv.active_out_channel = self.active_out_channel + + if self.inverted_bottleneck is not None: + x = self.inverted_bottleneck(x) + x = self.depth_conv(x) + x = self.point_linear(x) + return x + + @property + def module_str(self): + if self.use_se: + return "SE(O%d, E%.1f, K%d)" % ( + self.active_out_channel, + self.active_expand_ratio, + self.active_kernel_size, + ) + else: + return "(O%d, E%.1f, K%d)" % ( + self.active_out_channel, + self.active_expand_ratio, + self.active_kernel_size, + ) + + @property + def config(self): + return { + "name": DynamicMBConvLayer.__name__, + "in_channel_list": self.in_channel_list, + "out_channel_list": self.out_channel_list, + "kernel_size_list": self.kernel_size_list, + "expand_ratio_list": self.expand_ratio_list, + "stride": self.stride, + "act_func": self.act_func, + "use_se": self.use_se, + } + + @staticmethod + def build_from_config(config): + return DynamicMBConvLayer(**config) + + ############################################################################################ + + @property + def in_channels(self): + return max(self.in_channel_list) + + @property + def out_channels(self): + return max(self.out_channel_list) + + def active_middle_channel(self, in_channel): + return make_divisible(round(in_channel * self.active_expand_ratio)) + + ############################################################################################ + + def get_active_subnet(self, in_channel, preserve_weight=True): + # build the new layer + sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel)) + sub_layer = sub_layer.to(self.parameters().__next__().device) + if not preserve_weight: + return sub_layer + + middle_channel = self.active_middle_channel(in_channel) + # copy weight from current layer + if sub_layer.inverted_bottleneck is not None: + sub_layer.inverted_bottleneck.conv.weight.data.copy_( + self.inverted_bottleneck.conv.get_active_filter( + middle_channel, in_channel + ).data, + ) + copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn) + + sub_layer.depth_conv.conv.weight.data.copy_( + self.depth_conv.conv.get_active_filter( + middle_channel, self.active_kernel_size + ).data + ) + copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn) + + if self.use_se: + se_mid = make_divisible( + middle_channel // SEModule.REDUCTION, + ) + sub_layer.depth_conv.se.fc.reduce.weight.data.copy_( + self.depth_conv.se.get_active_reduce_weight(se_mid, middle_channel).data + ) + sub_layer.depth_conv.se.fc.reduce.bias.data.copy_( + self.depth_conv.se.get_active_reduce_bias(se_mid).data + ) + + sub_layer.depth_conv.se.fc.expand.weight.data.copy_( + self.depth_conv.se.get_active_expand_weight(se_mid, middle_channel).data + ) + sub_layer.depth_conv.se.fc.expand.bias.data.copy_( + self.depth_conv.se.get_active_expand_bias(middle_channel).data + ) + + sub_layer.point_linear.conv.weight.data.copy_( + self.point_linear.conv.get_active_filter( + self.active_out_channel, middle_channel + ).data + ) + copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn) + + return sub_layer + + def get_active_subnet_config(self, in_channel): + return { + "name": MBConvLayer.__name__, + "in_channels": in_channel, + "out_channels": self.active_out_channel, + "kernel_size": self.active_kernel_size, + "stride": self.stride, + "expand_ratio": self.active_expand_ratio, + "mid_channels": self.active_middle_channel(in_channel), + "act_func": self.act_func, + "use_se": self.use_se, + } + + def re_organize_middle_weights(self, expand_ratio_stage=0): + importance = torch.sum( + torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3) + ) + if isinstance(self.depth_conv.bn, DynamicGroupNorm): + channel_per_group = self.depth_conv.bn.channel_per_group + importance_chunks = torch.split(importance, channel_per_group) + for chunk in importance_chunks: + chunk.data.fill_(torch.mean(chunk)) + importance = torch.cat(importance_chunks, dim=0) + if expand_ratio_stage > 0: + sorted_expand_list = deepcopy(self.expand_ratio_list) + sorted_expand_list.sort(reverse=True) + target_width_list = [ + make_divisible(round(max(self.in_channel_list) * expand)) + for expand in sorted_expand_list + ] + + right = len(importance) + base = -len(target_width_list) * 1e5 + for i in range(expand_ratio_stage + 1): + left = target_width_list[i] + importance[left:right] += base + base += 1e5 + right = left + + sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True) + self.point_linear.conv.conv.weight.data = torch.index_select( + self.point_linear.conv.conv.weight.data, 1, sorted_idx + ) + + adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx) + self.depth_conv.conv.conv.weight.data = torch.index_select( + self.depth_conv.conv.conv.weight.data, 0, sorted_idx + ) + + if self.use_se: + # se expand: output dim 0 reorganize + se_expand = self.depth_conv.se.fc.expand + se_expand.weight.data = torch.index_select( + se_expand.weight.data, 0, sorted_idx + ) + se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx) + # se reduce: input dim 1 reorganize + se_reduce = self.depth_conv.se.fc.reduce + se_reduce.weight.data = torch.index_select( + se_reduce.weight.data, 1, sorted_idx + ) + # middle weight reorganize + se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3)) + se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True) + + se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx) + se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx) + se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx) + + if self.inverted_bottleneck is not None: + adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx) + self.inverted_bottleneck.conv.conv.weight.data = torch.index_select( + self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx + ) + return None + else: + return sorted_idx + + +class DynamicConvLayer(nn.Module): + def __init__( + self, + in_channel_list, + out_channel_list, + kernel_size=3, + stride=1, + dilation=1, + use_bn=True, + act_func="relu6", + ): + super(DynamicConvLayer, self).__init__() + + self.in_channel_list = in_channel_list + self.out_channel_list = out_channel_list + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.use_bn = use_bn + self.act_func = act_func + + self.conv = DynamicConv2d( + max_in_channels=max(self.in_channel_list), + max_out_channels=max(self.out_channel_list), + kernel_size=self.kernel_size, + stride=self.stride, + dilation=self.dilation, + ) + if self.use_bn: + self.bn = DynamicBatchNorm2d(max(self.out_channel_list)) + self.act = build_activation(self.act_func) + + self.active_out_channel = max(self.out_channel_list) + + def forward(self, x): + self.conv.active_out_channel = self.active_out_channel + + x = self.conv(x) + if self.use_bn: + x = self.bn(x) + x = self.act(x) + return x + + @property + def module_str(self): + return "DyConv(O%d, K%d, S%d)" % ( + self.active_out_channel, + self.kernel_size, + self.stride, + ) + + @property + def config(self): + return { + "name": DynamicConvLayer.__name__, + "in_channel_list": self.in_channel_list, + "out_channel_list": self.out_channel_list, + "kernel_size": self.kernel_size, + "stride": self.stride, + "dilation": self.dilation, + "use_bn": self.use_bn, + "act_func": self.act_func, + } + + @staticmethod + def build_from_config(config): + return DynamicConvLayer(**config) + + ############################################################################################ + + @property + def in_channels(self): + return max(self.in_channel_list) + + @property + def out_channels(self): + return max(self.out_channel_list) + + ############################################################################################ + + def get_active_subnet(self, in_channel, preserve_weight=True): + sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel)) + sub_layer = sub_layer.to(self.parameters().__next__().device) + + if not preserve_weight: + return sub_layer + + sub_layer.conv.weight.data.copy_( + self.conv.get_active_filter(self.active_out_channel, in_channel).data + ) + if self.use_bn: + copy_bn(sub_layer.bn, self.bn.bn) + + return sub_layer + + def get_active_subnet_config(self, in_channel): + return { + "name": ConvLayer.__name__, + "in_channels": in_channel, + "out_channels": self.active_out_channel, + "kernel_size": self.kernel_size, + "stride": self.stride, + "dilation": self.dilation, + "use_bn": self.use_bn, + "act_func": self.act_func, + } + + +class DynamicResNetBottleneckBlock(nn.Module): + def __init__( + self, + in_channel_list, + out_channel_list, + expand_ratio_list=0.25, + kernel_size=3, + stride=1, + act_func="relu", + downsample_mode="avgpool_conv", + ): + super(DynamicResNetBottleneckBlock, self).__init__() + + self.in_channel_list = in_channel_list + self.out_channel_list = out_channel_list + self.expand_ratio_list = val2list(expand_ratio_list) + + self.kernel_size = kernel_size + self.stride = stride + self.act_func = act_func + self.downsample_mode = downsample_mode + + # build modules + max_middle_channel = make_divisible( + round(max(self.out_channel_list) * max(self.expand_ratio_list))) + + self.conv1 = nn.Sequential( + OrderedDict( + [ + ( + "conv", + DynamicConv2d(max(self.in_channel_list), max_middle_channel), + ), + ("bn", DynamicBatchNorm2d(max_middle_channel)), + ("act", build_activation(self.act_func, inplace=True)), + ] + ) + ) + + self.conv2 = nn.Sequential( + OrderedDict( + [ + ( + "conv", + DynamicConv2d( + max_middle_channel, max_middle_channel, kernel_size, stride + ), + ), + ("bn", DynamicBatchNorm2d(max_middle_channel)), + ("act", build_activation(self.act_func, inplace=True)), + ] + ) + ) + + self.conv3 = nn.Sequential( + OrderedDict( + [ + ( + "conv", + DynamicConv2d(max_middle_channel, max(self.out_channel_list)), + ), + ("bn", DynamicBatchNorm2d(max(self.out_channel_list))), + ] + ) + ) + + if self.stride == 1 and self.in_channel_list == self.out_channel_list: + self.downsample = IdentityLayer( + max(self.in_channel_list), max(self.out_channel_list) + ) + elif self.downsample_mode == "conv": + self.downsample = nn.Sequential( + OrderedDict( + [ + ( + "conv", + DynamicConv2d( + max(self.in_channel_list), + max(self.out_channel_list), + stride=stride, + ), + ), + ("bn", DynamicBatchNorm2d(max(self.out_channel_list))), + ] + ) + ) + elif self.downsample_mode == "avgpool_conv": + self.downsample = nn.Sequential( + OrderedDict( + [ + ( + "avg_pool", + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + padding=0, + ceil_mode=True, + ), + ), + ( + "conv", + DynamicConv2d( + max(self.in_channel_list), max(self.out_channel_list) + ), + ), + ("bn", DynamicBatchNorm2d(max(self.out_channel_list))), + ] + ) + ) + else: + raise NotImplementedError + + self.final_act = build_activation(self.act_func, inplace=True) + + self.active_expand_ratio = max(self.expand_ratio_list) + self.active_out_channel = max(self.out_channel_list) + + def forward(self, x): + feature_dim = self.active_middle_channels + + self.conv1.conv.active_out_channel = feature_dim + self.conv2.conv.active_out_channel = feature_dim + self.conv3.conv.active_out_channel = self.active_out_channel + if not isinstance(self.downsample, IdentityLayer): + self.downsample.conv.active_out_channel = self.active_out_channel + + residual = self.downsample(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + + x = x + residual + x = self.final_act(x) + return x + + @property + def module_str(self): + return "(%s, %s)" % ( + "%dx%d_BottleneckConv_in->%d->%d_S%d" + % ( + self.kernel_size, + self.kernel_size, + self.active_middle_channels, + self.active_out_channel, + self.stride, + ), + "Identity" + if isinstance(self.downsample, IdentityLayer) + else self.downsample_mode, + ) + + @property + def config(self): + return { + "name": DynamicResNetBottleneckBlock.__name__, + "in_channel_list": self.in_channel_list, + "out_channel_list": self.out_channel_list, + "expand_ratio_list": self.expand_ratio_list, + "kernel_size": self.kernel_size, + "stride": self.stride, + "act_func": self.act_func, + "downsample_mode": self.downsample_mode, + } + + @staticmethod + def build_from_config(config): + return DynamicResNetBottleneckBlock(**config) + + ############################################################################################ + + @property + def in_channels(self): + return max(self.in_channel_list) + + @property + def out_channels(self): + return max(self.out_channel_list) + + @property + def active_middle_channels(self): + feature_dim = round(self.active_out_channel * self.active_expand_ratio) + feature_dim = make_divisible(feature_dim) + return feature_dim + + ############################################################################################ + + def get_active_subnet(self, in_channel, preserve_weight=True): + # build the new layer + sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel)) + sub_layer = sub_layer.to(self.parameters().__next__().device) + if not preserve_weight: + return sub_layer + + # copy weight from current layer + sub_layer.conv1.conv.weight.data.copy_( + self.conv1.conv.get_active_filter( + self.active_middle_channels, in_channel + ).data + ) + copy_bn(sub_layer.conv1.bn, self.conv1.bn.bn) + + sub_layer.conv2.conv.weight.data.copy_( + self.conv2.conv.get_active_filter( + self.active_middle_channels, self.active_middle_channels + ).data + ) + copy_bn(sub_layer.conv2.bn, self.conv2.bn.bn) + + sub_layer.conv3.conv.weight.data.copy_( + self.conv3.conv.get_active_filter( + self.active_out_channel, self.active_middle_channels + ).data + ) + copy_bn(sub_layer.conv3.bn, self.conv3.bn.bn) + + if not isinstance(self.downsample, IdentityLayer): + sub_layer.downsample.conv.weight.data.copy_( + self.downsample.conv.get_active_filter( + self.active_out_channel, in_channel + ).data + ) + copy_bn(sub_layer.downsample.bn, self.downsample.bn.bn) + + return sub_layer + + def get_active_subnet_config(self, in_channel): + return { + "name": ResNetBottleneckBlock.__name__, + "in_channels": in_channel, + "out_channels": self.active_out_channel, + "kernel_size": self.kernel_size, + "stride": self.stride, + "expand_ratio": self.active_expand_ratio, + "mid_channels": self.active_middle_channels, + "act_func": self.act_func, + "groups": 1, + "downsample_mode": self.downsample_mode, + } + + def re_organize_middle_weights(self, expand_ratio_stage=0): + # conv3 -> conv2 + importance = torch.sum( + torch.abs(self.conv3.conv.conv.weight.data), dim=(0, 2, 3) + ) + if isinstance(self.conv2.bn, DynamicGroupNorm): + channel_per_group = self.conv2.bn.channel_per_group + importance_chunks = torch.split(importance, channel_per_group) + for chunk in importance_chunks: + chunk.data.fill_(torch.mean(chunk)) + importance = torch.cat(importance_chunks, dim=0) + if expand_ratio_stage > 0: + sorted_expand_list = deepcopy(self.expand_ratio_list) + sorted_expand_list.sort(reverse=True) + target_width_list = [ + make_divisible(round(max(self.out_channel_list) * expand)) + for expand in sorted_expand_list + ] + right = len(importance) + base = -len(target_width_list) * 1e5 + for i in range(expand_ratio_stage + 1): + left = target_width_list[i] + importance[left:right] += base + base += 1e5 + right = left + + sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True) + self.conv3.conv.conv.weight.data = torch.index_select( + self.conv3.conv.conv.weight.data, 1, sorted_idx + ) + adjust_bn_according_to_idx(self.conv2.bn.bn, sorted_idx) + self.conv2.conv.conv.weight.data = torch.index_select( + self.conv2.conv.conv.weight.data, 0, sorted_idx + ) + + # conv2 -> conv1 + importance = torch.sum( + torch.abs(self.conv2.conv.conv.weight.data), dim=(0, 2, 3) + ) + if isinstance(self.conv1.bn, DynamicGroupNorm): + channel_per_group = self.conv1.bn.channel_per_group + importance_chunks = torch.split(importance, channel_per_group) + for chunk in importance_chunks: + chunk.data.fill_(torch.mean(chunk)) + importance = torch.cat(importance_chunks, dim=0) + if expand_ratio_stage > 0: + sorted_expand_list = deepcopy(self.expand_ratio_list) + sorted_expand_list.sort(reverse=True) + target_width_list = [ + make_divisible(round(max(self.out_channel_list) * expand)) + for expand in sorted_expand_list + ] + right = len(importance) + base = -len(target_width_list) * 1e5 + for i in range(expand_ratio_stage + 1): + left = target_width_list[i] + importance[left:right] += base + base += 1e5 + right = left + sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True) + + self.conv2.conv.conv.weight.data = torch.index_select( + self.conv2.conv.conv.weight.data, 1, sorted_idx + ) + adjust_bn_according_to_idx(self.conv1.bn.bn, sorted_idx) + self.conv1.conv.conv.weight.data = torch.index_select( + self.conv1.conv.conv.weight.data, 0, sorted_idx + ) + + return None + + +def sub_filter_start_end(kernel_size, sub_kernel_size): + center = kernel_size // 2 + dev = sub_kernel_size // 2 + start, end = center - dev, center + dev + 1 + assert end - start == sub_kernel_size + return start, end + + +def adjust_bn_according_to_idx(bn, idx): + bn.weight.data = torch.index_select(bn.weight.data, 0, idx) + bn.bias.data = torch.index_select(bn.bias.data, 0, idx) + if type(bn) in [nn.BatchNorm1d, nn.BatchNorm2d]: + bn.running_mean.data = torch.index_select(bn.running_mean.data, 0, idx) + bn.running_var.data = torch.index_select(bn.running_var.data, 0, idx) + + +def copy_bn(target_bn, src_bn): + feature_dim = ( + target_bn.num_channels + if isinstance(target_bn, nn.GroupNorm) + else target_bn.num_features + ) + + target_bn.weight.data.copy_(src_bn.weight.data[:feature_dim]) + target_bn.bias.data.copy_(src_bn.bias.data[:feature_dim]) + if type(src_bn) in [nn.BatchNorm1d, nn.BatchNorm2d]: + target_bn.running_mean.data.copy_(src_bn.running_mean.data[:feature_dim]) + target_bn.running_var.data.copy_(src_bn.running_var.data[:feature_dim]) diff --git a/xnas/spaces/OFA/ops.py b/xnas/spaces/OFA/ops.py new file mode 100644 index 0000000..ca6fbc8 --- /dev/null +++ b/xnas/spaces/OFA/ops.py @@ -0,0 +1,957 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import OrderedDict + +from xnas.spaces.OFA.utils import ( + min_divisible_value, + get_same_padding, + make_divisible, +) + + +def set_layer_from_config(layer_config): + if layer_config is None: + return None + + name2layer = { + ConvLayer.__name__: ConvLayer, + IdentityLayer.__name__: IdentityLayer, + LinearLayer.__name__: LinearLayer, + MultiHeadLinearLayer.__name__: MultiHeadLinearLayer, + ZeroLayer.__name__: ZeroLayer, + MBConvLayer.__name__: MBConvLayer, + "MBInvertedConvLayer": MBConvLayer, + ResidualBlock.__name__: ResidualBlock, + ResNetBottleneckBlock.__name__: ResNetBottleneckBlock, + } + + layer_name = layer_config.pop("name") + layer = name2layer[layer_name] + return layer.build_from_config(layer_config) + + +"""Activation.""" + +def build_activation(act_func, inplace=True): + if act_func == "relu": + return nn.ReLU(inplace=inplace) + elif act_func == "relu6": + return nn.ReLU6(inplace=inplace) + elif act_func == "tanh": + return nn.Tanh() + elif act_func == "sigmoid": + return nn.Sigmoid() + elif act_func == "h_swish": + return Hswish(inplace=inplace) + elif act_func == "h_sigmoid": + return Hsigmoid(inplace=inplace) + elif act_func is None or act_func == "none": + return None + else: + raise ValueError("do not support: %s" % act_func) + + +class Hswish(nn.Module): + def __init__(self, inplace=True): + super(Hswish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0 + + def __repr__(self): + return "Hswish()" + + +class Hsigmoid(nn.Module): + def __init__(self, inplace=True): + super(Hsigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return F.relu6(x + 3.0, inplace=self.inplace) / 6.0 + + def __repr__(self): + return "Hsigmoid()" + + + +class ShuffleLayer(nn.Module): + def __init__(self, groups): + super(ShuffleLayer, self).__init__() + self.groups = groups + + def forward(self, x): + batch_size, num_channels, height, width = x.size() + channels_per_group = num_channels // self.groups + # reshape + x = x.view(batch_size, self.groups, channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + # flatten + x = x.view(batch_size, -1, height, width) + return x + + def __repr__(self): + return "ShuffleLayer(groups=%d)" % self.groups + + +class GlobalAvgPool2d(nn.Module): + def __init__(self, keep_dim=True): + super(GlobalAvgPool2d, self).__init__() + self.keep_dim = keep_dim + + def forward(self, x): + return x.mean(3, keepdim=self.keep_dim).mean(2, keepdim=self.keep_dim) + + def __repr__(self): + return "GlobalAvgPool2d(keep_dim=%s)" % self.keep_dim + + +class SEModule(nn.Module): + REDUCTION = 4 + + def __init__(self, channel, reduction=None): + super(SEModule, self).__init__() + + self.channel = channel + self.reduction = SEModule.REDUCTION if reduction is None else reduction + + num_mid = make_divisible(self.channel // self.reduction) + + self.fc = nn.Sequential( + OrderedDict( + [ + ("reduce", nn.Conv2d(self.channel, num_mid, 1, 1, 0, bias=True)), + ("relu", nn.ReLU(inplace=True)), + ("expand", nn.Conv2d(num_mid, self.channel, 1, 1, 0, bias=True)), + ("h_sigmoid", Hsigmoid(inplace=True)), + ] + ) + ) + + def forward(self, x): + y = x.mean(3, keepdim=True).mean(2, keepdim=True) + y = self.fc(y) + return x * y + + def __repr__(self): + return "SE(channel=%d, reduction=%d)" % (self.channel, self.reduction) + + +class WeightStandardConv2d(nn.Conv2d): + """ + Conv2d with Weight Standardization + https://github.com/joe-siyuan-qiao/WeightStandardization + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + super(WeightStandardConv2d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + self.WS_EPS = None + + def weight_standardization(self, weight): + if self.WS_EPS is not None: + weight_mean = ( + weight.mean(dim=1, keepdim=True) + .mean(dim=2, keepdim=True) + .mean(dim=3, keepdim=True) + ) + weight = weight - weight_mean + std = ( + weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + + self.WS_EPS + ) + weight = weight / std.expand_as(weight) + return weight + + def forward(self, x): + if self.WS_EPS is None: + return super(WeightStandardConv2d, self).forward(x) + else: + return F.conv2d( + x, + self.weight_standardization(self.weight), + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + def __repr__(self): + return super(WeightStandardConv2d, self).__repr__()[:-1] + ", ws_eps=%s)" % self.WS_EPS + + + +# Basic layer to init with Dropout, Conv, BN and activations (order not fixed.) +class MixedDCBRLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + use_bn=True, + act_func="relu", + dropout_rate=0, + ops_order="weight_bn_act", + ): + super(MixedDCBRLayer, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.use_bn = use_bn + self.act_func = act_func + self.dropout_rate = dropout_rate + self.ops_order = ops_order + + modules = {} + # batch norm + if self.use_bn: + if self.bn_before_weight: + modules["bn"] = nn.BatchNorm2d(in_channels) + else: + modules["bn"] = nn.BatchNorm2d(out_channels) + else: + modules["bn"] = None + # activation + modules["act"] = build_activation( + self.act_func, self.ops_list[0] != "act" and self.use_bn + ) + # dropout + if self.dropout_rate > 0: + modules["dropout"] = nn.Dropout2d(self.dropout_rate, inplace=True) + else: + modules["dropout"] = None + # weight + modules["weight"] = self.weight_op() + + # add modules + for op in self.ops_list: + if modules[op] is None: + continue + elif op == "weight": + # dropout before weight operation + if modules["dropout"] is not None: + self.add_module("dropout", modules["dropout"]) + for key in modules["weight"]: + self.add_module(key, modules["weight"][key]) + else: + self.add_module(op, modules[op]) + + @property + def ops_list(self): + return self.ops_order.split("_") + + @property + def bn_before_weight(self): + for op in self.ops_list: + if op == "bn": + return True + elif op == "weight": + return False + raise ValueError("Invalid ops_order: %s" % self.ops_order) + + def forward(self, x): + # similar to nn.Sequential + for module in self._modules.values(): + x = module(x) + return x + + @property + def config(self): + return { + "in_channels": self.in_channels, + "out_channels": self.out_channels, + "use_bn": self.use_bn, + "act_func": self.act_func, + "dropout_rate": self.dropout_rate, + "ops_order": self.ops_order, + } + + +class ConvLayer(MixedDCBRLayer): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + dilation=1, + groups=1, + bias=False, + has_shuffle=False, + use_se=False, + use_bn=True, + act_func="relu", + dropout_rate=0, + ops_order="weight_bn_act", + ): + # default normal 3x3_Conv with bn and relu + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.groups = groups + self.bias = bias + self.has_shuffle = has_shuffle + self.use_se = use_se + + super(ConvLayer, self).__init__( + in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order + ) + if self.use_se: + self.add_module("se", SEModule(self.out_channels)) + + def weight_op(self): + padding = get_same_padding(self.kernel_size) + if isinstance(padding, int): + padding *= self.dilation + else: + padding[0] *= self.dilation + padding[1] *= self.dilation + + weight_dict = OrderedDict( + { + "conv": nn.Conv2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=padding, + dilation=self.dilation, + groups=min_divisible_value(self.in_channels, self.groups), + bias=self.bias, + ) + } + ) + if self.has_shuffle and self.groups > 1: + weight_dict["shuffle"] = ShuffleLayer(self.groups) + + return weight_dict + + @property + def module_str(self): + if isinstance(self.kernel_size, int): + kernel_size = (self.kernel_size, self.kernel_size) + else: + kernel_size = self.kernel_size + if self.groups == 1: + if self.dilation > 1: + conv_str = "%dx%d_DilatedConv" % (kernel_size[0], kernel_size[1]) + else: + conv_str = "%dx%d_Conv" % (kernel_size[0], kernel_size[1]) + else: + if self.dilation > 1: + conv_str = "%dx%d_DilatedGroupConv" % (kernel_size[0], kernel_size[1]) + else: + conv_str = "%dx%d_GroupConv" % (kernel_size[0], kernel_size[1]) + conv_str += "_O%d" % self.out_channels + if self.use_se: + conv_str = "SE_" + conv_str + conv_str += "_" + self.act_func.upper() + if self.use_bn: + if isinstance(self.bn, nn.GroupNorm): + conv_str += "_GN%d" % self.bn.num_groups + elif isinstance(self.bn, nn.BatchNorm2d): + conv_str += "_BN" + return conv_str + + @property + def config(self): + return { + "name": ConvLayer.__name__, + "kernel_size": self.kernel_size, + "stride": self.stride, + "dilation": self.dilation, + "groups": self.groups, + "bias": self.bias, + "has_shuffle": self.has_shuffle, + "use_se": self.use_se, + **super(ConvLayer, self).config, + } + + +class IdentityLayer(MixedDCBRLayer): + def __init__( + self, + in_channels, + out_channels, + use_bn=False, + act_func=None, + dropout_rate=0, + ops_order="weight_bn_act", + ): + super(IdentityLayer, self).__init__( + in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order + ) + + def weight_op(self): + return None + + @property + def module_str(self): + return "Identity" + + @property + def config(self): + return { + "name": IdentityLayer.__name__, + **super(IdentityLayer, self).config, + } + + @staticmethod + def build_from_config(config): + return IdentityLayer(**config) + + +class LinearLayer(nn.Module): + def __init__( + self, + in_features, + out_features, + bias=True, + use_bn=False, + act_func=None, + dropout_rate=0, + ops_order="weight_bn_act", + ): + super(LinearLayer, self).__init__() + + self.in_features = in_features + self.out_features = out_features + self.bias = bias + + self.use_bn = use_bn + self.act_func = act_func + self.dropout_rate = dropout_rate + self.ops_order = ops_order + + """ modules """ + modules = {} + # batch norm + if self.use_bn: + if self.bn_before_weight: + modules["bn"] = nn.BatchNorm1d(in_features) + else: + modules["bn"] = nn.BatchNorm1d(out_features) + else: + modules["bn"] = None + # activation + modules["act"] = build_activation(self.act_func, self.ops_list[0] != "act") + # dropout + if self.dropout_rate > 0: + modules["dropout"] = nn.Dropout(self.dropout_rate, inplace=True) + else: + modules["dropout"] = None + # linear + modules["weight"] = { + "linear": nn.Linear(self.in_features, self.out_features, self.bias) + } + + # add modules + for op in self.ops_list: + if modules[op] is None: + continue + elif op == "weight": + if modules["dropout"] is not None: + self.add_module("dropout", modules["dropout"]) + for key in modules["weight"]: + self.add_module(key, modules["weight"][key]) + else: + self.add_module(op, modules[op]) + + @property + def ops_list(self): + return self.ops_order.split("_") + + @property + def bn_before_weight(self): + for op in self.ops_list: + if op == "bn": + return True + elif op == "weight": + return False + raise ValueError("Invalid ops_order: %s" % self.ops_order) + + def forward(self, x): + for module in self._modules.values(): + x = module(x) + return x + + @property + def module_str(self): + return "%dx%d_Linear" % (self.in_features, self.out_features) + + @property + def config(self): + return { + "name": LinearLayer.__name__, + "in_features": self.in_features, + "out_features": self.out_features, + "bias": self.bias, + "use_bn": self.use_bn, + "act_func": self.act_func, + "dropout_rate": self.dropout_rate, + "ops_order": self.ops_order, + } + + @staticmethod + def build_from_config(config): + return LinearLayer(**config) + + +class MultiHeadLinearLayer(nn.Module): + def __init__( + self, in_features, out_features, num_heads=1, bias=True, dropout_rate=0 + ): + super(MultiHeadLinearLayer, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.num_heads = num_heads + + self.bias = bias + self.dropout_rate = dropout_rate + + if self.dropout_rate > 0: + self.dropout = nn.Dropout(self.dropout_rate, inplace=True) + else: + self.dropout = None + + self.layers = nn.ModuleList() + for k in range(num_heads): + layer = nn.Linear(in_features, out_features, self.bias) + self.layers.append(layer) + + def forward(self, inputs): + if self.dropout is not None: + inputs = self.dropout(inputs) + + outputs = [] + for layer in self.layers: + output = layer.forward(inputs) + outputs.append(output) + + outputs = torch.stack(outputs, dim=1) + return outputs + + @property + def module_str(self): + return self.__repr__() + + @property + def config(self): + return { + "name": MultiHeadLinearLayer.__name__, + "in_features": self.in_features, + "out_features": self.out_features, + "num_heads": self.num_heads, + "bias": self.bias, + "dropout_rate": self.dropout_rate, + } + + @staticmethod + def build_from_config(config): + return MultiHeadLinearLayer(**config) + + def __repr__(self): + return ( + "MultiHeadLinear(in_features=%d, out_features=%d, num_heads=%d, bias=%s, dropout_rate=%s)" % ( + self.in_features, + self.out_features, + self.num_heads, + self.bias, + self.dropout_rate, + )) + + +class ZeroLayer(nn.Module): + def __init__(self): + super(ZeroLayer, self).__init__() + + def forward(self, x): + raise ValueError + + @property + def module_str(self): + return "Zero" + + @property + def config(self): + return { + "name": ZeroLayer.__name__, + } + + @staticmethod + def build_from_config(config): + return ZeroLayer() + + +class MBConvLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + expand_ratio=6, + mid_channels=None, + act_func="relu6", + use_se=False, + groups=None, + ): + super(MBConvLayer, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + self.kernel_size = kernel_size + self.stride = stride + self.expand_ratio = expand_ratio + self.mid_channels = mid_channels + self.act_func = act_func + self.use_se = use_se + self.groups = groups + + if self.mid_channels is None: + feature_dim = round(self.in_channels * self.expand_ratio) + else: + feature_dim = self.mid_channels + + if self.expand_ratio == 1: + self.inverted_bottleneck = None + else: + self.inverted_bottleneck = nn.Sequential( + OrderedDict( + [ + ( + "conv", + nn.Conv2d( + self.in_channels, feature_dim, 1, 1, 0, bias=False + ), + ), + ("bn", nn.BatchNorm2d(feature_dim)), + ("act", build_activation(self.act_func, inplace=True)), + ] + ) + ) + + pad = get_same_padding(self.kernel_size) + groups = ( + feature_dim + if self.groups is None + else min_divisible_value(feature_dim, self.groups) + ) + depth_conv_modules = [ + ( + "conv", + nn.Conv2d( + feature_dim, + feature_dim, + kernel_size, + stride, + pad, + groups=groups, + bias=False, + ), + ), + ("bn", nn.BatchNorm2d(feature_dim)), + ("act", build_activation(self.act_func, inplace=True)), + ] + if self.use_se: + depth_conv_modules.append(("se", SEModule(feature_dim))) + self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules)) + + self.point_linear = nn.Sequential( + OrderedDict( + [ + ("conv", nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)), + ("bn", nn.BatchNorm2d(out_channels)), + ] + ) + ) + + def forward(self, x): + if self.inverted_bottleneck: + x = self.inverted_bottleneck(x) + x = self.depth_conv(x) + x = self.point_linear(x) + return x + + @property + def module_str(self): + if self.mid_channels is None: + expand_ratio = self.expand_ratio + else: + expand_ratio = self.mid_channels // self.in_channels + layer_str = "%dx%d_MBConv%d_%s" % ( + self.kernel_size, + self.kernel_size, + expand_ratio, + self.act_func.upper(), + ) + if self.use_se: + layer_str = "SE_" + layer_str + layer_str += "_O%d" % self.out_channels + if self.groups is not None: + layer_str += "_G%d" % self.groups + if isinstance(self.point_linear.bn, nn.GroupNorm): + layer_str += "_GN%d" % self.point_linear.bn.num_groups + elif isinstance(self.point_linear.bn, nn.BatchNorm2d): + layer_str += "_BN" + + return layer_str + + @property + def config(self): + return { + "name": MBConvLayer.__name__, + "in_channels": self.in_channels, + "out_channels": self.out_channels, + "kernel_size": self.kernel_size, + "stride": self.stride, + "expand_ratio": self.expand_ratio, + "mid_channels": self.mid_channels, + "act_func": self.act_func, + "use_se": self.use_se, + "groups": self.groups, + } + + @staticmethod + def build_from_config(config): + return MBConvLayer(**config) + + +class ResidualBlock(nn.Module): + def __init__(self, conv, shortcut): + super(ResidualBlock, self).__init__() + + self.conv = conv + self.shortcut = shortcut + + def forward(self, x): + if self.conv is None or isinstance(self.conv, ZeroLayer): + res = x + elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer): + res = self.conv(x) + else: + res = self.conv(x) + self.shortcut(x) + return res + + @property + def module_str(self): + return "(%s, %s)" % ( + self.conv.module_str if self.conv is not None else None, + self.shortcut.module_str if self.shortcut is not None else None, + ) + + @property + def config(self): + return { + "name": ResidualBlock.__name__, + "conv": self.conv.config if self.conv is not None else None, + "shortcut": self.shortcut.config if self.shortcut is not None else None, + } + + @staticmethod + def build_from_config(config): + conv_config = ( + config["conv"] if "conv" in config else config["mobile_inverted_conv"] + ) + conv = set_layer_from_config(conv_config) + shortcut = set_layer_from_config(config["shortcut"]) + return ResidualBlock(conv, shortcut) + + @property + def mobile_inverted_conv(self): + return self.conv + + +class ResNetBottleneckBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + expand_ratio=0.25, + mid_channels=None, + act_func="relu", + groups=1, + downsample_mode="avgpool_conv", + ): + super(ResNetBottleneckBlock, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + self.kernel_size = kernel_size + self.stride = stride + self.expand_ratio = expand_ratio + self.mid_channels = mid_channels + self.act_func = act_func + self.groups = groups + + self.downsample_mode = downsample_mode + + if self.mid_channels is None: + feature_dim = round(self.out_channels * self.expand_ratio) + else: + feature_dim = self.mid_channels + + feature_dim = make_divisible(feature_dim) + self.mid_channels = feature_dim + + # build modules + self.conv1 = nn.Sequential( + OrderedDict( + [ + ( + "conv", + nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False), + ), + ("bn", nn.BatchNorm2d(feature_dim)), + ("act", build_activation(self.act_func, inplace=True)), + ] + ) + ) + + pad = get_same_padding(self.kernel_size) + self.conv2 = nn.Sequential( + OrderedDict( + [ + ( + "conv", + nn.Conv2d( + feature_dim, + feature_dim, + kernel_size, + stride, + pad, + groups=groups, + bias=False, + ), + ), + ("bn", nn.BatchNorm2d(feature_dim)), + ("act", build_activation(self.act_func, inplace=True)), + ] + ) + ) + + self.conv3 = nn.Sequential( + OrderedDict( + [ + ( + "conv", + nn.Conv2d(feature_dim, self.out_channels, 1, 1, 0, bias=False), + ), + ("bn", nn.BatchNorm2d(self.out_channels)), + ] + ) + ) + + if stride == 1 and in_channels == out_channels: + self.downsample = IdentityLayer(in_channels, out_channels) + elif self.downsample_mode == "conv": + self.downsample = nn.Sequential( + OrderedDict( + [ + ( + "conv", + nn.Conv2d( + in_channels, out_channels, 1, stride, 0, bias=False + ), + ), + ("bn", nn.BatchNorm2d(out_channels)), + ] + ) + ) + elif self.downsample_mode == "avgpool_conv": + self.downsample = nn.Sequential( + OrderedDict( + [ + ( + "avg_pool", + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + padding=0, + ceil_mode=True, + ), + ), + ( + "conv", + nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False), + ), + ("bn", nn.BatchNorm2d(out_channels)), + ] + ) + ) + else: + raise NotImplementedError + + self.final_act = build_activation(self.act_func, inplace=True) + + def forward(self, x): + residual = self.downsample(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + + x = x + residual + x = self.final_act(x) + return x + + @property + def module_str(self): + return "(%s, %s)" % ( + "%dx%d_BottleneckConv_%d->%d->%d_S%d_G%d" + % ( + self.kernel_size, + self.kernel_size, + self.in_channels, + self.mid_channels, + self.out_channels, + self.stride, + self.groups, + ), + "Identity" + if isinstance(self.downsample, IdentityLayer) + else self.downsample_mode, + ) + + @property + def config(self): + return { + "name": ResNetBottleneckBlock.__name__, + "in_channels": self.in_channels, + "out_channels": self.out_channels, + "kernel_size": self.kernel_size, + "stride": self.stride, + "expand_ratio": self.expand_ratio, + "mid_channels": self.mid_channels, + "act_func": self.act_func, + "groups": self.groups, + "downsample_mode": self.downsample_mode, + } + + @staticmethod + def build_from_config(config): + return ResNetBottleneckBlock(**config) diff --git a/xnas/spaces/OFA/utils.py b/xnas/spaces/OFA/utils.py new file mode 100644 index 0000000..1d830db --- /dev/null +++ b/xnas/spaces/OFA/utils.py @@ -0,0 +1,111 @@ +"""Utilities.""" + +import math +import numpy as np +import torch.nn as nn + + +def list_sum(x): + return x[0] if len(x) == 1 else x[0] + list_sum(x[1:]) + + +def list_mean(x): + return list_sum(x) / len(x) + + +def min_divisible_value(n1, v1): + """make sure v1 is divisible by n1, otherwise decrease v1""" + if v1 >= n1: + return n1 + while n1 % v1 != 0: + v1 -= 1 + return v1 + +def val2list(val, repeat_time=1): + if isinstance(val, list) or isinstance(val, np.ndarray): + return val + elif isinstance(val, tuple): + return list(val) + else: + return [val for _ in range(repeat_time)] + + +""" Layer releated""" + +def get_same_padding(kernel_size): + if isinstance(kernel_size, tuple): + assert len(kernel_size) == 2, "invalid kernel size: %s" % kernel_size + p1 = get_same_padding(kernel_size[0]) + p2 = get_same_padding(kernel_size[1]) + return p1, p2 + assert isinstance(kernel_size, int), "kernel size should be either `int` or `tuple`" + assert kernel_size % 2 > 0, "kernel size should be odd number" + return kernel_size // 2 + +def make_divisible(v, divisor=8, min_val=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_val: + :return: + """ + if min_val is None: + min_val = divisor + new_v = max(min_val, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +""" BN related """ + +def clean_num_batch_tracked(net): + for m in net.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + if m.num_batches_tracked is not None: + m.num_batches_tracked.zero_() + + +def rm_bn_from_net(net): + for m in net.modules(): + if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + m.forward = lambda x: x + + +""" network related """ + +def init_model(net, model_init="he_fout"): + """ + Conv2d, + BatchNorm2d, BatchNorm1d, GroupNorm + Linear, + """ + if isinstance(net, list): + for sub_net in net: + init_models(sub_net, model_init) + return + for m in net.modules(): + if isinstance(m, nn.Conv2d): + if model_init == "he_fout": + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + elif model_init == "he_fin": + n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + else: + raise NotImplementedError + if m.bias is not None: + m.bias.data.zero_() + elif type(m) in [nn.BatchNorm2d, nn.BatchNorm1d, nn.GroupNorm]: + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + stdv = 1.0 / math.sqrt(m.weight.size(1)) + m.weight.data.uniform_(-stdv, stdv) + if m.bias is not None: + m.bias.data.zero_() -- 2.34.1 From 7225f96ff294952fb19ed63777148af85f02eca9 Mon Sep 17 00:00:00 2001 From: xfey Date: Tue, 24 May 2022 16:51:35 +0800 Subject: [PATCH 5/5] fix typos in README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5daa51d..bc3b6af 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ We are gradually providing support for more settings. To run XNAS, `python>=3.7` and `pytorch=1.9` are required. Other versions of `PyTorch` may also work well, but there are potential API differences that can cause warnings to be generated. -For detailed instructions, please refer to [**Getting_started.md**](./docs/get_started.md) and [**Data_preparation.md**](./docs/data_preparation.md) in our docs. +For detailed instructions, please refer to [**get_started.md**](./docs/get_started.md) and [**data_preparation.md**](./docs/data_preparation.md) in our docs. ## Contributing -- 2.34.1