#1 yao

Merged
Lumen3ever merged 46 commits from yao into master 1 year ago
  1. +130
    -0
      configs/model_utils/config.py
  2. +115
    -0
      configs/model_utils/moxing_adapter.py
  3. +64
    -0
      configs/vit_patch32_imagenet2012_config_cloud.yml
  4. +81
    -0
      scripts/run_ViT_train_distribute.sh
  5. +241
    -0
      tools/train_ViT.py
  6. +108
    -0
      xbm/core/callback_ViT.py
  7. +124
    -0
      xbm/core/cross_entropy_ViT.py
  8. +105
    -0
      xbm/core/eval_engine_ViT.py
  9. +80
    -0
      xbm/core/logging_ViT.py
  10. +93
    -0
      xbm/core/lr_generator_ViT.py
  11. +0
    -40
      xbm/core/optimizer.py
  12. +214
    -0
      xbm/core/optimizer_ViT.py
  13. +264
    -0
      xbm/datasets/autoaugment_ViT.py
  14. +171
    -0
      xbm/datasets/dataset_ViT.py
  15. +506
    -0
      xbm/models/vit.py

+ 130
- 0
configs/model_utils/config.py View File

@@ -0,0 +1,130 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Parse arguments"""

import os
import ast
import argparse
from pprint import pprint, pformat
import yaml

class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)

def __str__(self):
return pformat(self.__dict__)

def __repr__(self):
return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.

Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args


def parse_yaml(yaml_path):
"""
Parse the yaml config file.

Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.

Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg


def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str,
default=os.path.join(current_dir, "../vit_patch32_imagenet2012_config_cloud.yml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
config_path = path_args.config_path
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=config_path)
final_config = merge(args, default)
pprint(final_config)
print("Please check the above information for the configurations", flush=True)
return Config(final_config)

config = get_config()

+ 115
- 0
configs/model_utils/moxing_adapter.py View File

@@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Moxing adapter for ModelArts"""

import os
import functools
from mindspore import context
from .config import config

_global_sync_count = 0

def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)


def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)


def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)


def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id != "" else "default"
return job_id

def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1

# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")

while True:
if os.path.exists(sync_lock):
break
time.sleep(1)

print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))

context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)

if pre_process:
pre_process()

run_func(*args, **kwargs)

# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()

if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

+ 64
- 0
configs/vit_patch32_imagenet2012_config_cloud.yml View File

@@ -0,0 +1,64 @@
enable_modelarts: 1

# Url for modelarts
data_url: "s3://zhengxiawu/data/ImageNet2012"
train_url: "s3://zhengxiawu/project/yao/output"
checkpoint_url: ""
output_path: "s3://zhengxiawu/project/yao/output"
data_path: "s3://zhengxiawu/data/ImageNet2012"
load_path: ""

# train datasets
dataset_path: 's3://zhengxiawu/data/ImageNet2012/train'
train_image_size: 224
interpolation: 'BILINEAR'
crop_min: 0.05
batch_size: 256
train_num_workers: 14

# eval datasets
eval_path: 's3://zhengxiawu/data/ImageNet2012/val'
eval_image_size: 224
eval_batch_size: 256
eval_interval: 1
eval_offset: -1
eval_num_workers: 12

# network
backbone: 'vit_base_patch32'
class_num: 1001
vit_config_path: 'src.vit.VitConfig'
pretrained: ''

# lr
lr_decay_mode: 'cosine'
lr_init: 0.0
lr_max: 0.00355
lr_min: 0.0
max_epoch: 300
warmup_epochs: 40

# optimizer
opt: 'adamw'
beta1: 0.9
beta2: 0.999
weight_decay: 0.05
no_weight_decay_filter: "beta,bias"
gc_flag: 0

# loss
loss_scale: 1024
use_label_smooth: 1
label_smooth_factor: 0.1
mixup: 0.2
autoaugment: 1
loss_name: "ce_smooth_mixup"

# ckpt
save_checkpoint: 1
save_checkpoint_epochs: 8
keep_checkpoint_max: 3
save_checkpoint_path: './outputs'

# profiler
open_profiler: 0

+ 81
- 0
scripts/run_ViT_train_distribute.sh View File

@@ -0,0 +1,81 @@
#!/bin/bash
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
then
echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [CONFIG_PATH]"
exit 1
fi

get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}

PATH1=$(get_real_path $1)
CONFIG_FILE=$(get_real_path $2)

if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi

if [ ! -f $CONFIG_FILE ]
then
echo "error: config_path=$CONFIG_FILE is not a file"
exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))

cpus=`cat /proc/cpuinfo| grep "processor"| wc -l`
avg=`expr $cpus \/ $DEVICE_NUM`
gap=`expr $avg \- 1`

for((i=0; i<${DEVICE_NUM}; i++))
do
start=`expr $i \* $avg`
end=`expr $start \+ $gap`
cmdopt=$start"-"$end
export DEVICE_ID=${i}
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../config/*.yml ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log

if [ $# == 2 ]
then
taskset -c $cmdopt python train.py --config_path=$CONFIG_FILE &> log &
fi

cd ..
done

+ 241
- 0
tools/train_ViT.py View File

@@ -0,0 +1,241 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""training script"""

import os
import time
import socket
import numpy as np

from mindspore import context
from mindspore import Tensor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.communication.management import init
from mindspore.profiler.profiling import Profiler
from mindspore.train.serialization import load_checkpoint
import mindspore.dataset as ds

from xbm.models.vit import get_network
from xbm.datasets.dataset_ViT import get_dataset
from xbm.core.cross_entropy_ViT import get_loss
from xbm.core.optimizer_ViT import get_optimizer
from xbm.core.lr_generator_ViT import get_lr
from xbm.core.eval_engine_ViT import get_eval_engine
from xbm.core.callback_ViT import StateMonitor
from xbm.core.logging_ViT import get_logger

from configs.model_utils.config import config
from configs.model_utils.moxing_adapter import moxing_wrapper

try:
os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = os.getenv('RANK_TABLE_FILE')

device_id = int(os.getenv('DEVICE_ID')) # 0 ~ 7
local_rank = int(os.getenv('RANK_ID')) # local_rank
device_num = int(os.getenv('RANK_SIZE')) # world_size
print("distribute training")
except TypeError:
device_id = 0 # 0 ~ 7
local_rank = 0 # local_rank
device_num = 1 # world_size
print("standalone training")

def add_static_args(args):
"""add_static_args"""
args.weight_decay = float(args.weight_decay)

args.eval_engine = 'imagenet'
args.split_point = 0.4
args.poly_power = 2
args.aux_factor = 0.4
args.seed = 1
args.auto_tune = 0

if args.eval_offset < 0:
args.eval_offset = args.max_epoch % args.eval_interval

args.device_id = device_id
args.local_rank = local_rank
args.device_num = device_num
args.dataset_name = 'imagenet'

return args

def modelarts_pre_process():
'''modelarts pre process function.'''
start_t = time.time()

val_file = os.path.join(config.data_path, 'val/imagenet_val.tar')
train_file = os.path.join(config.data_path, 'train/imagenet_train.tar')
tar_files = [val_file, train_file]

print('tar_files:{}'.format(tar_files))
for tar_file in tar_files:
if os.path.exists(tar_file):
t1 = time.time()
tar_dir = os.path.dirname(tar_file)
print('cd {}; tar -xvf {} > /dev/null 2>&1'.format(tar_dir, tar_file))
os.system('cd {}; tar -xvf {} > /dev/null 2>&1'.format(tar_dir, tar_file))
t2 = time.time()
print('uncompress, time used={:.2f}s'.format(t2 - t1))
os.system('cd {}; rm -rf {}'.format(tar_dir, tar_file))
else:
print('file no exists:', tar_file)

end_t = time.time()
print('tar cost time {:.2f} sec'.format(end_t-start_t))


@moxing_wrapper(pre_process=modelarts_pre_process)
def train_net():
"""train_net"""
args = add_static_args(config)
np.random.seed(args.seed)
args.logger = get_logger(args.save_checkpoint_path, rank=local_rank)

context.set_context(device_id=device_id,
mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False)

if args.auto_tune:
context.set_context(auto_tune_mode='GA')
elif args.device_num == 1:
pass
else:
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)

if args.open_profiler:
profiler = Profiler(output_path="data_{}".format(local_rank))

# init the distribute env
if not args.auto_tune and args.device_num > 1:
init()

# network
net = get_network(backbone_name=args.backbone, args=args)

# set grad allreduce split point
parameters = [param for param in net.trainable_params()]
parameter_len = len(parameters)
if args.split_point > 0:
print("split_point={}".format(args.split_point))
split_parameter_index = [int(args.split_point*parameter_len),]
parameter_indices = 1
for i in range(parameter_len):
if i in split_parameter_index:
parameter_indices += 1
parameters[i].comm_fusion = parameter_indices
else:
print("warning!!!, no split point")

if os.path.isfile(args.pretrained):
load_checkpoint(args.pretrained, net, strict_load=False)

# loss
if not args.use_label_smooth:
args.label_smooth_factor = 0.0
loss = get_loss(loss_name=args.loss_name, args=args)

# train dataset
epoch_size = args.max_epoch
dataset = get_dataset(dataset_name=args.dataset_name,
do_train=True,
dataset_path=args.dataset_path,
args=args)
ds.config.set_seed(args.seed)
step_size = dataset.get_dataset_size()
args.steps_per_epoch = step_size

# evaluation dataset
eval_dataset = get_dataset(dataset_name=args.dataset_name,
do_train=False,
dataset_path=args.eval_path,
args=args)

# evaluation engine
if args.auto_tune or args.open_profiler or eval_dataset is None or args.device_num == 1:
args.eval_engine = ''
eval_engine = get_eval_engine(args.eval_engine, net, eval_dataset, args)

# loss scale
loss_scale = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)

# learning rate
lr_array = get_lr(global_step=0, lr_init=args.lr_init, lr_end=args.lr_min, lr_max=args.lr_max,
warmup_epochs=args.warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size,
lr_decay_mode=args.lr_decay_mode, poly_power=args.poly_power)
lr = Tensor(lr_array)

# optimizer, group_params used in grad freeze
opt, _ = get_optimizer(optimizer_name=args.opt,
network=net,
lrs=lr,
args=args)

# model
model = Model(net, loss_fn=loss, optimizer=opt,
metrics=eval_engine.metric, eval_network=eval_engine.eval_network,
loss_scale_manager=loss_scale, amp_level="O3")
eval_engine.set_model(model)
args.logger.save_args(args)

t0 = time.time()
# equal to model._init(dataset, sink_size=step_size)
eval_engine.compile(sink_size=step_size)

t1 = time.time()
args.logger.info('compile time used={:.2f}s'.format(t1 - t0))

# callbacks
state_cb = StateMonitor(data_size=step_size,
tot_batch_size=args.batch_size * device_num,
lrs=lr_array,
eval_interval=args.eval_interval,
eval_offset=args.eval_offset,
eval_engine=eval_engine,
logger=args.logger.info)

cb = [state_cb,]
if args.save_checkpoint and local_rank == 0:
config_ck = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_epochs*step_size,
keep_checkpoint_max=args.keep_checkpoint_max,
async_save=True)
ckpt_cb = ModelCheckpoint(prefix=args.backbone, directory=args.save_checkpoint_path, config=config_ck)
cb += [ckpt_cb]

t0 = time.time()
model.train(epoch_size, dataset, callbacks=cb, sink_size=step_size)
t1 = time.time()
args.logger.info('training time used={:.2f}s'.format(t1 - t0))
last_metric = 'last_metric[{}]'.format(state_cb.best_acc)
args.logger.info(last_metric)

is_cloud = args.enable_modelarts
if is_cloud:
ip = os.getenv("BATCH_TASK_CURRENT_HOST_IP")
else:
ip = socket.gethostbyname(socket.gethostname())
args.logger.info('ip[{}], mean_fps[{:.2f}]'.format(ip, state_cb.mean_fps))

if args.open_profiler:
profiler.analyse()

if __name__ == '__main__':
train_net()

+ 108
- 0
xbm/core/callback_ViT.py View File

@@ -0,0 +1,108 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""callbacks"""

import time
import numpy as np
from mindspore.train.callback import Callback
from mindspore.common.tensor import Tensor

class StateMonitor(Callback):
"""StateMonitor"""
def __init__(self, data_size, tot_batch_size=None, lrs=None,
eval_interval=None, eval_offset=None, eval_engine=None, logger=None):
super(StateMonitor, self).__init__()
self.data_size = data_size
self.tot_batch_size = tot_batch_size
self.lrs = lrs
self.epoch_num = 0
self.loss = 0
self.eval_interval = eval_interval
self.eval_offset = eval_offset
self.eval_engine = eval_engine
self.best_acc = -1
self.best_acc_top5 = -1
self.best_i2t_recall = -1
self.best_t2i_recall = -1
self.mean_fps = 0.0
self.print = print
if logger is not None:
self.print = logger


def step_end(self, run_context):
cb_params = run_context.original_args()
loss = cb_params.net_outputs

if isinstance(loss, (tuple, list)):
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
loss = loss[0]

if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
loss = np.mean(loss.asnumpy())

self.loss = loss

def epoch_begin(self, run_context):
self.epoch_time = time.time()

def epoch_end(self, run_context):
epoch_seconds = (time.time() - self.epoch_time)
per_step_seconds = epoch_seconds / self.data_size

print_str = "epoch[{}]".format(self.epoch_num)
print_str += ', epoch time: {:.2f}s'.format(epoch_seconds)
print_str += ', per step time: {:.4f}s'.format(per_step_seconds)
print_str += ', loss={:.6f}'.format(self.loss)

if self.lrs is not None:
lr = self.lrs[(self.epoch_num + 1) * self.data_size - 1]
print_str += ', lr={:.6f}'.format(lr)

if self.tot_batch_size is not None:
fps = self.tot_batch_size * self.data_size / epoch_seconds
self.mean_fps = (self.mean_fps * self.epoch_num + fps) / (self.epoch_num + 1)
print_str += ', fps={:.2f}'.format(fps)

if (self.epoch_num + 1) % self.eval_interval == self.eval_offset:
eval_start = time.time()
self.eval_engine.eval()
output = self.eval_engine.get_result()
eval_seconds = time.time() - eval_start
if output is not None:
if isinstance(output, list):
print_str += ', top1 accuracy={:.6f}'.format(float(output[0]))
print_str += ', top5 accuracy={:.6f}'.format(float(output[1]))
print_str += ', i2t_recall={:.6f}'.format(float(output[2]))
print_str += ', t2i_recall={:.6f}'.format(float(output[3]))
print_str += ', eval_cost={:.2f}'.format(eval_seconds)

if float(output[0]) > self.best_acc:
self.best_acc = float(output[0])
if float(output[1]) > self.best_acc_top5:
self.best_acc_top5 = float(output[1])
if float(output[2]) > self.best_i2t_recall:
self.best_i2t_recall = float(output[2])
if float(output[3]) > self.best_t2i_recall:
self.best_t2i_recall = float(output[3])
else:
print_str += ', accuracy={:.6f}'.format(float(output))
print_str += ', eval_cost={:.2f}'.format(eval_seconds)

if float(output) > self.best_acc:
self.best_acc = float(output)

self.print(print_str)
self.epoch_num += 1

+ 124
- 0
xbm/core/cross_entropy_ViT.py View File

@@ -0,0 +1,124 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""loss functions"""

from mindspore import nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
try:
from mindspore.nn.loss.loss import Loss
except ImportError:
try:
from mindspore.nn.loss.loss import LossBase as Loss
except ImportError:
from mindspore.nn.loss.loss import _Loss as Loss

from mindspore.ops import functional as F
from mindspore.ops import operations as P


class CrossEntropySmooth(Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000, aux_factor=0.4):
super().__init__()
self.aux_factor = aux_factor
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

def construct(self, logits, label):
if isinstance(logits, tuple):
logit, aux_logit = logits
else:
logit, aux_logit = logits, None

if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)

loss = self.ce(logit, label)
if aux_logit is not None:
loss = loss + self.aux_factor * self.ce(aux_logit, label)
return loss


class CrossEntropySmoothMixup(Loss):
"""CrossEntropy"""
def __init__(self, reduction='mean', smooth_factor=0., num_classes=1000):
super().__init__()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = 1.0 * smooth_factor / (num_classes - 2)
self.cross_entropy = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

def construct(self, logit, label):
off_label = P.Select()(P.Equal()(label, 0.0), \
P.Fill()(mstype.float32, P.Shape()(label), self.off_value), \
P.Fill()(mstype.float32, P.Shape()(label), 0.0))

label = self.on_value * label + off_label
loss = self.cross_entropy(logit, label)
return loss


class CrossEntropyIgnore(Loss):
"""CrossEntropyIgnore"""
def __init__(self, num_classes=21, ignore_label=255):
super().__init__()
self.one_hot = P.OneHot(axis=-1)
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.cast = P.Cast()
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.not_equal = P.NotEqual()
self.num_cls = num_classes
self.ignore_label = ignore_label
self.mul = P.Mul()
self.sum = P.ReduceSum(False)
self.div = P.RealDiv()
self.transpose = P.Transpose()
self.reshape = P.Reshape()

def construct(self, logits, labels):
labels_int = self.cast(labels, mstype.int32)
labels_int = self.reshape(labels_int, (-1,))
logits_ = self.transpose(logits, (0, 2, 3, 1))
logits_ = self.reshape(logits_, (-1, self.num_cls))
weights = self.not_equal(labels_int, self.ignore_label)
weights = self.cast(weights, mstype.float32)
one_hot_labels = self.one_hot(labels_int, self.num_cls, self.on_value, self.off_value)
loss = self.ce(logits_, one_hot_labels)
loss = self.mul(weights, loss)
loss = self.div(self.sum(loss), self.sum(weights))
return loss


def get_loss(loss_name, args):
"""get_loss"""
loss = None
if loss_name == 'ce_smooth':
loss = CrossEntropySmooth(smooth_factor=args.label_smooth_factor,
num_classes=args.class_num,
aux_factor=args.aux_factor)
elif loss_name == 'ce_smooth_mixup':
loss = CrossEntropySmoothMixup(smooth_factor=args.label_smooth_factor,
num_classes=args.class_num)
elif loss_name == 'ce_ignore':
loss = CrossEntropyIgnore(num_classes=args.class_num,
ignore_label=args.ignore_label)
else:
raise NotImplementedError

return loss

+ 105
- 0
xbm/core/eval_engine_ViT.py View File

@@ -0,0 +1,105 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval engine"""

from mindspore import Tensor
import mindspore.common.dtype as mstype

from src.metric import ClassifyCorrectWithCache, ClassifyCorrectCell, DistAccuracy

class BasicEvalEngine():
"""BasicEvalEngine"""
def __init__(self):
pass

@property
def metric(self):
return None

@property
def eval_network(self):
return None

def compile(self, sink_size=-1):
pass

def eval(self):
pass

def set_model(self, model):
self.model = model

def get_result(self):
return None

class ImageNetCacheEvelEngine(BasicEvalEngine):
"""ImageNetCacheEvelEngine"""
def __init__(self, net, eval_dataset, args):
super().__init__()
self.dist_eval_network = ClassifyCorrectWithCache(net, eval_dataset)
self.outputs = None
self.args = args

def compile(self, sink_size=-1):
index = Tensor(0, mstype.int32)
self.dist_eval_network.set_train(False)
self.dist_eval_network.compile(index)

def eval(self):
index = Tensor(0, mstype.int32)
output = self.dist_eval_network(index)
output = output.asnumpy() / 50000
self.outputs = {"acc": output}

def get_result(self):
return self.outputs["acc"]


class ImageNetEvelEngine(BasicEvalEngine):
"""ImageNetEvelEngine"""
def __init__(self, net, eval_dataset, args):
super().__init__()
self.eval_dataset = eval_dataset
self.dist_eval_network = ClassifyCorrectCell(net)
self.args = args
self.outputs = None
self.model = None

@property
def metric(self):
return {'acc': DistAccuracy(batch_size=self.args.eval_batch_size, device_num=self.args.device_num)}

@property
def eval_network(self):
return self.dist_eval_network

def eval(self):
self.outputs = self.model.eval(self.eval_dataset)

def get_result(self):
return self.outputs["acc"]

def get_eval_engine(engine_name, net, eval_dataset, args):
"""get_eval_engine"""
if engine_name == '':
eval_engine = BasicEvalEngine()
elif engine_name == "imagenet":
eval_engine = ImageNetEvelEngine(net, eval_dataset, args)
elif engine_name == "imagenet_cache":
eval_engine = ImageNetCacheEvelEngine(net, eval_dataset, args)
else:
raise NotImplementedError

return eval_engine

+ 80
- 0
xbm/core/logging_ViT.py View File

@@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""logging"""

import logging
import os
import sys
from datetime import datetime

logger_name = 'mindspore-benchmark'


class LOGGER(logging.Logger):
"""
LOGGER
"""
def __init__(self, logger_name_local, rank=0):
super().__init__(logger_name_local)
self.log_fn = None
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s', "%Y-%m-%d %H:%M:%S")
console.setFormatter(formatter)
self.addHandler(console)

def setup_logging_file(self, log_dir, rank=0):
"""setup_logging_file"""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
self.log_fn = log_fn

def info(self, msg, *args, **kwargs):
"""info"""
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)

def save_args(self, args):
"""save_args"""
self.info('Args:')
if isinstance(args, (list, tuple)):
for value in args:
message = '--> {}'.format(value)
self.info(message)
else:
if isinstance(args, dict):
args_dict = args
else:
args_dict = vars(args)
for key in args_dict.keys():
message = '--> {}: {}'.format(key, args_dict[key])
self.info(message)
self.info('')


def get_logger(path, rank=0):
"""get_logger"""
logger = LOGGER(logger_name, rank)
logger.setup_logging_file(path, rank)
return logger

+ 93
- 0
xbm/core/lr_generator_ViT.py View File

@@ -0,0 +1,93 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""

import math
import numpy as np

def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr

def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, \
total_epochs, steps_per_epoch, lr_decay_mode, poly_power=2.0):
"""
generate learning rate array

Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or default

Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = int(steps_per_epoch * warmup_epochs)
if lr_decay_mode == 'steps':
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
for i in range(total_steps):
if i < decay_epoch_index[0]:
lr = lr_max
elif i < decay_epoch_index[1]:
lr = lr_max * 0.1
elif i < decay_epoch_index[2]:
lr = lr_max * 0.01
else:
lr = lr_max * 0.001
lr_each_step.append(lr)
elif lr_decay_mode == 'poly':
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr = float(lr_init) + inc_each_step * float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr = float(lr_max - lr_end) * base ** poly_power + lr_end
lr = max(lr, 0.0)
lr_each_step.append(lr)
elif lr_decay_mode == 'cosine':
decay_steps = total_steps - warmup_steps
for i in range(total_steps):
if i < warmup_steps:
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
lr = float(lr_init) + lr_inc * (i + 1)
else:
cur_step = i + 1 - warmup_steps
lr = lr_max * (1 + math.cos(math.pi * cur_step / decay_steps)) / 2
lr_each_step.append(lr)
else:
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
lr_each_step.append(lr)

current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]

return learning_rate

+ 0
- 40
xbm/core/optimizer.py View File

@@ -1,40 +0,0 @@
from mindspore.nn.optim import AdamWeightDecay
from mindspore.common import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
import mindspore.common.dtype as mstype


class FP32StateAdamWeightDecay(AdamWeightDecay):
r"""
This class is almost same with the mindspore's AdamWeightDecay implements, the
only difference is the optimizer's state will be always initialized with float32,
where the original AdamWeightDecay will initialize the optimizer's state with float16,
if the parameters are initialized with fp16.
This setting will avoid overflow in training PanGu-Alpha model using fp16.
"""

def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
super(FP32StateAdamWeightDecay, self).__init__(params, learning_rate=learning_rate,
beta1=beta1,
beta2=beta2,
eps=eps,
weight_decay=weight_decay)

self.moments1 = self.clone_state(self.parameters, prefix='adam_m', init='zeros')
self.moments2 = self.clone_state(self.parameters, prefix='adam_v', init='zeros')

def clone_state(self, parameter_tuple, prefix, init):
r"""
parameter_tuple: ParameterTuple. The parameters of the network
prefix: str. The prefix name of the parameters
init: str. The initialization method
"""
new = []
for old_param in parameter_tuple:
new_state = Parameter(initializer(init, shape=old_param.shape, dtype=mstype.float32))
new_state.param_info = old_param.param_info.clone()
new_state.is_init = False
new_state.set_data(initializer(init, shape=old_param.shape, dtype=mstype.float32))
new_state.name = prefix + '.' + new_state.name
new.append(new_state)
return ParameterTuple(new)

+ 214
- 0
xbm/core/optimizer_ViT.py View File

@@ -0,0 +1,214 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Gradient clipping wrapper for optimizers."""

import numpy as np

from mindspore._checkparam import Validator as validator
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import composite as C

from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Rel
from mindspore.nn.optim import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register


def _check_param_value(beta1, beta2, eps, prim_name):
"""Check the type of inputs."""
validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
validator.check_positive_float(eps, "eps", prim_name)


_grad_scale = C.MultitypeFuncGraph("grad_scale")
op_mul = P.Mul()
map_ = C.Map()


@_grad_scale.register("Number", "Tensor")
def tensor_grad_scale(scale, grad):
"""Get grad with scale."""
if scale == 1.0:
return grad
return op_mul(grad, F.cast(scale, F.dtype(grad)))


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale_with_tensor(scale, grad):
"""Get grad with scale."""
return op_mul(grad, F.cast(scale, F.dtype(grad)))


def scale_grad(gradients, reciprocal_scale):
gradients = map_(F.partial(_grad_scale, reciprocal_scale), gradients)
return gradients


_adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)


@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op(beta1_power, beta2_power, beta1, beta2, eps, lr, weight_decay, param, \
m, v, gradient, decay_flag, optim_filter):
"""
Update parameters.

Args:
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
lr (Tensor): Learning rate.
weight_decay (Number): Weight decay. Should be equal to or greater than 0.
param (Tensor): Parameters.
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Applies weight decay or not.
optim_filter (bool): Applies parameter update or not.

Returns:
Tensor, the new value of v after updating.
"""
if optim_filter:
# op_mul = P.Mul(), defined output
op_square = P.Square()
op_sqrt = P.Sqrt()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()

param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)

next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta1, gradient_fp32)

next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta2, op_square(gradient_fp32))

regulate_m = next_m / (_scaler_one - beta1_power)
regulate_v = next_v / (_scaler_one - beta2_power)

update = regulate_m / (eps + op_sqrt(regulate_v))
if decay_flag:
update = op_mul(weight_decay, param_fp32) + update

update_with_lr = op_mul(lr, update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

return op_cast(next_param, F.dtype(param))
return gradient


class AdamW(Optimizer):
"""
Implements the gradient clipping by norm for a AdamWeightDecay optimizer.
"""
@opt_init_args_register
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, \
weight_decay=0.0, loss_scale=1.0, clip=False):
super(AdamW, self).__init__(learning_rate, params, weight_decay)
_check_param_value(beta1, beta2, eps, self.cls_name)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
self.hyper_map = C.HyperMap()
self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")

self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
self.clip = clip

def construct(self, gradients):
lr = self.get_lr()
gradients = scale_grad(gradients, self.reciprocal_scale)
if self.clip:
gradients = C.clip_by_global_norm(gradients, 5.0, None)

beta1_power = self.beta1_power * self.beta1
self.beta1_power = beta1_power
beta2_power = self.beta2_power * self.beta2
self.beta2_power = beta2_power

if self.is_group:
if self.is_group_lr:
optim_result = self.hyper_map(F.partial(_adam_opt, beta1_power, beta2_power, \
self.beta1, self.beta2, self.eps),
lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, beta1_power, beta2_power, \
self.beta1, self.beta2, self.eps, lr),
self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, beta1_power, beta2_power, self.beta1, self.beta2, \
self.eps, lr, self.weight_decay),
self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
if self.use_parallel:
self.broadcast_params(optim_result)
return optim_result

def paramter_group(network, weight_decay, no_weight_decay_filter, gc_flag):
"""paramter_group"""
filter_len = len(no_weight_decay_filter)
if filter_len > 0:
decayed_params = []
no_decayed_params = []
for param in network.trainable_params():
if all([key not in param.name for key in no_weight_decay_filter]):
decayed_params.append(param)
else:
no_decayed_params.append(param)

group_params = [{'params': decayed_params, 'weight_decay': weight_decay, 'grad_centralization': gc_flag},
{'params': no_decayed_params},
{'order_params': network.trainable_params()}]
else:
group_params = [{'params': network.trainable_params(), \
'weight_decay': weight_decay, 'grad_centralization': gc_flag},
{'order_params': network.trainable_params()}]

return group_params

def get_optimizer(optimizer_name, network, lrs, args):
no_weight_decay_filter = [x for x in args.no_weight_decay_filter.split(",") if len(x) > 0]
group_params = paramter_group(network, args.weight_decay, no_weight_decay_filter, bool(args.gc_flag))

if optimizer_name == 'adamw':
opt = AdamW(group_params, lrs, args.beta1, args.beta2, loss_scale=args.loss_scale)
else:
raise NotImplementedError

return opt, group_params

+ 264
- 0
xbm/datasets/autoaugment_ViT.py View File

@@ -0,0 +1,264 @@
# MIT License

# Copyright (c) 2018 Philip Popien

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# ============================================================================

"""
This code is based on https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
"""

import random
from PIL import Image, ImageEnhance, ImageOps
import numpy as np


class ImageNetPolicy():
""" Randomly choose one of the best 24 Sub-policies on ImageNet.
Example:
>>> policy = ImageNetPolicy()
>>> transformed = policy(image)
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),

SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),

SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),

SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),

SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
]

def __call__(self, img, policy_idx=None):
if policy_idx is None or not isinstance(policy_idx, int):
policy_idx = random.randint(0, len(self.policies) - 1)
else:
policy_idx = policy_idx % len(self.policies)
return self.policies[policy_idx](img)

def __repr__(self):
return "AutoAugment ImageNet Policy"


class CIFAR10Policy():
""" Randomly choose one of the best 25 Sub-policies on CIFAR10.
Example:
>>> policy = CIFAR10Policy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> CIFAR10Policy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),

SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),

SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),

SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),

SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
]

def __call__(self, img, policy_idx=None):
if policy_idx is None or not isinstance(policy_idx, int):
policy_idx = random.randint(0, len(self.policies) - 1)
else:
policy_idx = policy_idx % len(self.policies)
return self.policies[policy_idx](img)

def __repr__(self):
return "AutoAugment CIFAR10 Policy"


class SVHNPolicy():
""" Randomly choose one of the best 25 Sub-policies on SVHN.
Example:
>>> policy = SVHNPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> SVHNPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),

SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),

SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),

SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),

SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
]

def __call__(self, img, policy_idx=None):
if policy_idx is None or not isinstance(policy_idx, int):
policy_idx = random.randint(0, len(self.policies) - 1)
else:
policy_idx = policy_idx % len(self.policies)
return self.policies[policy_idx](img)

def __repr__(self):
return "AutoAugment SVHN Policy"


class SubPolicy():
"""
Randomly choose one of the best 14 Sub-policies on SubPolicy.
"""
def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
ranges = {
"shearX": np.linspace(0, 0.3, 10),
"shearY": np.linspace(0, 0.3, 10),
"translateX": np.linspace(0, 150 / 331, 10),
"translateY": np.linspace(0, 150 / 331, 10),
"rotate": np.linspace(0, 30, 10),
"color": np.linspace(0.0, 0.9, 10),
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
"solarize": np.linspace(256, 0, 10),
"contrast": np.linspace(0.0, 0.9, 10),
"sharpness": np.linspace(0.0, 0.9, 10),
"brightness": np.linspace(0.0, 0.9, 10),
"autocontrast": [0] * 10,
"equalize": [0] * 10,
"invert": [0] * 10
}

# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)

# pylint: disable = unnecessary-lambda
func = {
"shearX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
fillcolor=fillcolor),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
"posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img)
}

self.p1 = p1
self.operation1 = func[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
self.p2 = p2
self.operation2 = func[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]

def __call__(self, img):
if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
return img

+ 171
- 0
xbm/datasets/dataset_ViT.py View File

@@ -0,0 +1,171 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create train or eval dataset."""

import os
import warnings
from io import BytesIO
from PIL import Image
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.py_transforms as P
from mindspore.dataset.vision.utils import Inter

from .autoaugment import ImageNetPolicy

warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

class ToNumpy:
def __init__(self):
pass

def __call__(self, img):
return np.asarray(img)

def create_dataset(dataset_path,
do_train,
image_size=224,
interpolation='BILINEAR',
crop_min=0.05,
repeat_num=1,
batch_size=32,
num_workers=12,
autoaugment=False,
mixup=0.0,
num_classes=1001):
"""create_dataset"""

if hasattr(Inter, interpolation):
interpolation = getattr(Inter, interpolation)
else:
interpolation = Inter.BILINEAR
print('cannot find interpolation_type: {}, use {} instead'.format(interpolation, 'BILINEAR'))

device_num = int(os.getenv("RANK_SIZE", '1'))
rank_id = int(os.getenv('RANK_ID', '0'))

if do_train:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True,
num_shards=device_num, shard_id=rank_id)
else:
batch_per_step = batch_size * device_num
print("eval batch per step: {}".format(batch_per_step))
if batch_per_step < 50000:
if 50000 % batch_per_step == 0:
num_padded = 0
else:
num_padded = batch_per_step - (50000 % batch_per_step)
else:
num_padded = batch_per_step - 50000
print("eval dataset num_padded: {}".format(num_padded))

if num_padded != 0:
# padded_with_decode
white_io = BytesIO()
Image.new('RGB', (image_size, image_size), (255, 255, 255)).save(white_io, 'JPEG')
padded_sample = {
'image': np.array(bytearray(white_io.getvalue()), dtype='uint8'),
'label': np.array(-1, np.int32)
}
sample = [padded_sample for x in range(num_padded)]
ds_pad = de.PaddedDataset(sample)
ds_imagefolder = de.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers)
ds = ds_pad + ds_imagefolder
distribute_sampler = de.DistributedSampler(num_shards=device_num, shard_id=rank_id, \
shuffle=False, num_samples=None)
ds.use_sampler(distribute_sampler)
else:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, \
shuffle=False, num_shards=device_num, shard_id=rank_id)
print("eval dataset size: {}".format(ds.get_dataset_size()))

mean = [0.485*255, 0.456*255, 0.406*255]
std = [0.229*255, 0.224*255, 0.225*255]

# define map operations
if do_train:
trans = [
C.RandomCropDecodeResize(image_size, scale=(crop_min, 1.0), \
ratio=(0.75, 1.333), interpolation=interpolation),
C.RandomHorizontalFlip(prob=0.5),
]
if autoaugment:
trans += [
P.ToPIL(),
ImageNetPolicy(),
ToNumpy(),
]
trans += [
C.Normalize(mean=mean, std=std),
C.HWC2CHW(),
]
else:
resize = int(int(image_size / 0.875 / 16 + 0.5) * 16)
print('eval, resize:{}'.format(resize))
trans = [
C.Decode(),
C.Resize(resize, interpolation=interpolation),
C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]

type_cast_op = C2.TypeCast(mstype.int32)

ds = ds.repeat(repeat_num)
ds = ds.map(input_columns="image", num_parallel_workers=num_workers, operations=trans, python_multiprocessing=True)
ds = ds.map(input_columns="label", num_parallel_workers=num_workers, operations=type_cast_op)

if do_train and mixup > 0:
one_hot_encode = C2.OneHot(num_classes)
ds = ds.map(operations=one_hot_encode, input_columns=["label"])

ds = ds.batch(batch_size, drop_remainder=True)

if do_train and mixup > 0:
trans_mixup = C.MixUpBatch(alpha=mixup)
ds = ds.map(input_columns=["image", "label"], num_parallel_workers=num_workers, operations=trans_mixup)

return ds


def get_dataset(dataset_name, do_train, dataset_path, args):
"""get_dataset"""
if dataset_name == "imagenet":
if do_train:
data = create_dataset(dataset_path=dataset_path,
do_train=True,
image_size=args.train_image_size,
interpolation=args.interpolation,
autoaugment=args.autoaugment,
mixup=args.mixup,
crop_min=args.crop_min,
batch_size=args.batch_size,
num_workers=args.train_num_workers)
else:
data = create_dataset(dataset_path=dataset_path,
do_train=False,
image_size=args.eval_image_size,
interpolation=args.interpolation,
batch_size=args.eval_batch_size,
num_workers=args.eval_num_workers)
else:
raise NotImplementedError

return data

+ 506
- 0
xbm/models/vit.py View File

@@ -0,0 +1,506 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Vision Transformer implementation."""

from importlib import import_module
from easydict import EasyDict as edict
import numpy as np

import mindspore
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.nn import Cell, Dense, Dropout, SequentialCell
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
from mindspore import Tensor

MIN_NUM_PATCHES = 4

class VitConfig:
"""
VitConfig
"""
def __init__(self, configs):
self.configs = configs

# network init
self.network_norm = mindspore.nn.LayerNorm((configs.normalized_shape,))
self.network_init = mindspore.common.initializer.Normal(sigma=1.0)
self.network_dropout_rate = 0.1
self.network_pool = 'cls'
self.network = ViT

# stem
self.stem_init = mindspore.common.initializer.XavierUniform()
self.stem = VitStem

# body
self.body_norm = mindspore.nn.LayerNorm
self.body_drop_path_rate = 0.1
self.body = Transformer

# body attention
self.attention_init = mindspore.common.initializer.XavierUniform()
self.attention_activation = mindspore.nn.Softmax()
self.attention_dropout_rate = 0.1
self.attention = Attention

# body feedforward
self.feedforward_init = mindspore.common.initializer.XavierUniform()
self.feedforward_activation = mindspore.nn.GELU()
self.feedforward_dropout_rate = 0.1
self.feedforward = FeedForward

# head
self.head = origin_head
self.head_init = mindspore.common.initializer.XavierUniform()
self.head_dropout_rate = 0.1
self.head_norm = mindspore.nn.LayerNorm((configs.normalized_shape,))
self.head_activation = mindspore.nn.GELU()


class DropPath(Cell):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""

def __init__(self, drop_prob=None, seed=0):
super(DropPath, self).__init__()
self.keep_prob = 1 - drop_prob
seed = min(seed, 0) # always be 0
self.rand = P.UniformReal(seed=seed) # seed must be 0, if set to other value, it's not rand for multiple call
self.shape = P.Shape()
self.floor = P.Floor()

def construct(self, x):
if self.training:
x_shape = self.shape(x) # B N C
random_tensor = self.rand((x_shape[0], 1, 1))
random_tensor = random_tensor + self.keep_prob
random_tensor = self.floor(random_tensor)
x = x / self.keep_prob
x = x * random_tensor
return x


class BatchDense(Cell):
"""BatchDense module."""

def __init__(self, in_features, out_features, initialization, has_bias=True):
super().__init__()
self.out_features = out_features
self.dense = Dense(in_features, out_features, has_bias=has_bias)
self.dense.weight.set_data(initializer(initialization, [out_features, in_features]))
self.reshape = P.Reshape()

def construct(self, x):
bs, seq_len, d_model = x.shape
out = self.reshape(x, (bs * seq_len, d_model))
out = self.dense(out)
out = self.reshape(out, (bs, seq_len, self.out_features))
return out


class ResidualCell(Cell):
"""Cell which implements x + f(x) function."""
def __init__(self, cell):
super().__init__()
self.cell = cell

def construct(self, x, **kwargs):
return self.cell(x, **kwargs) + x


def pretrain_head(vit_config):
"""Head for ViT pretraining."""
d_model = vit_config.configs.d_model
mlp_dim = vit_config.configs.mlp_dim
num_classes = vit_config.configs.num_classes

dropout_rate = vit_config.head_dropout_rate
initialization = vit_config.head_init
normalization = vit_config.head_norm
activation = vit_config.head_activation

dense1 = Dense(d_model, mlp_dim)
dense1.weight.set_data(initializer(initialization, [mlp_dim, d_model]))
dense2 = Dense(mlp_dim, num_classes)
dense2.weight.set_data(initializer(initialization, [num_classes, mlp_dim]))

return SequentialCell([
normalization,
dense1,
activation,
Dropout(keep_prob=(1. - dropout_rate)),
dense2])


def origin_head(vit_config):
"""Head for ViT pretraining."""
d_model = vit_config.configs.d_model
num_classes = vit_config.configs.num_classes
initialization = vit_config.head_init
dense = Dense(d_model, num_classes)
dense.weight.set_data(initializer(initialization, [num_classes, d_model]))
return SequentialCell([dense])


class VitStem(Cell):
"""Stem layer for ViT."""

def __init__(self, vit_config):
super().__init__()
d_model = vit_config.configs.d_model
patch_size = vit_config.configs.patch_size
image_size = vit_config.configs.image_size
initialization = vit_config.stem_init
channels = 3

assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches {num_patches} is too small'
patch_dim = channels * patch_size ** 2

self.patch_size = patch_size
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.patch_to_embedding = BatchDense(patch_dim, d_model, initialization, has_bias=True)

def construct(self, img):
p = self.patch_size
bs, channels, h, w = img.shape
x = self.reshape(img, (bs, channels, h // p, p, w // p, p))
x = self.transpose(x, (0, 2, 4, 1, 3, 5))
x = self.reshape(x, (bs, (h//p)*(w//p), channels*p*p))
x = self.patch_to_embedding(x)
return x


class ViT(Cell):
"""Vision Transformer implementation."""

def __init__(self, vit_config):
super().__init__()

d_model = vit_config.configs.d_model
patch_size = vit_config.configs.patch_size
image_size = vit_config.configs.image_size

initialization = vit_config.network_init
pool = vit_config.network_pool
dropout_rate = vit_config.network_dropout_rate
norm = vit_config.network_norm

stem = vit_config.stem(vit_config)
body = vit_config.body(vit_config)
head = vit_config.head(vit_config)

assert pool in {'cls', 'mean'}, 'pool type must be either cls or mean'
num_patches = (image_size // patch_size) ** 2

if pool == "cls":
self.cls_token = Parameter(initializer(initialization, (1, 1, d_model)),
name='cls', requires_grad=True)
self.pos_embedding = Parameter(initializer(initialization, (1, num_patches + 1, d_model)),
name='pos_embedding', requires_grad=True)
self.tile = P.Tile()
self.cat_1 = P.Concat(axis=1)
else:
self.pos_embedding = Parameter(initializer(initialization, (1, num_patches, d_model)),
name='pos_embedding', requires_grad=True)
self.mean = P.ReduceMean(keep_dims=False)
self.pool = pool

self.cast = P.Cast()
self.dropout = Dropout(keep_prob=(1. - dropout_rate))
self.stem = stem
self.body = body
self.head = head
self.norm = norm

def construct(self, img):
x = self.stem(img)
bs, seq_len, _ = x.shape

if self.pool == "cls":
cls_tokens = self.tile(self.cls_token, (bs, 1, 1))
x = self.cat_1((cls_tokens, x)) # now x has shape = (bs, seq_len+1, d)
x += self.pos_embedding[:, :(seq_len + 1)]
else:
x += self.pos_embedding[:, :seq_len]

y = self.cast(x, mstype.float32)
y = self.dropout(y)
x = self.cast(y, x.dtype)

x = self.body(x)

if self.norm is not None:
x = self.norm(x)

if self.pool == "cls":
x = x[:, 0]
else:
x = self.mean(x, (-2,))

return self.head(x)


class Attention(Cell):
"""Attention layer implementation."""

def __init__(self, vit_config):
super().__init__()
d_model = vit_config.configs.d_model
dim_head = vit_config.configs.dim_head
heads = vit_config.configs.heads

initialization = vit_config.attention_init
activation = vit_config.attention_activation
dropout_rate = vit_config.attention_dropout_rate

inner_dim = heads * dim_head
self.dim_head = dim_head
self.heads = heads
self.scale = Tensor([dim_head ** -0.5])

self.to_q = Dense(d_model, inner_dim, has_bias=True)
self.to_q.weight.set_data(initializer(initialization, [inner_dim, d_model]))
self.to_k = Dense(d_model, inner_dim, has_bias=True)
self.to_k.weight.set_data(initializer(initialization, [inner_dim, d_model]))
self.to_v = Dense(d_model, inner_dim, has_bias=True)
self.to_v.weight.set_data(initializer(initialization, [inner_dim, d_model]))

self.to_out = Dense(inner_dim, d_model, has_bias=True)
self.to_out.weight.set_data(initializer(initialization, [inner_dim, d_model]))
self.dropout = Dropout(1 - dropout_rate)

self.activation = activation

#auxiliary functions
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.cast = P.Cast()
self.mul = P.Mul()
self.q_matmul_k = P.BatchMatMul(transpose_b=True)
self.attn_matmul_v = P.BatchMatMul()
self.softmax_nz = True

def construct(self, x):
'''x size - BxNxd_model'''
bs, seq_len, d_model, h, d = x.shape[0], x.shape[1], x.shape[2], self.heads, self.dim_head

x_2d = self.reshape(x, (-1, d_model))
q, k, v = self.to_q(x_2d), self.to_k(x_2d), self.to_v(x_2d)

if self.softmax_nz:
q = self.reshape(q, (bs, seq_len, h, d))
q = self.transpose(q, (0, 2, 1, 3))
q = self.cast(q, mstype.float32)
q = self.mul(q, self.scale)

k = self.reshape(k, (bs, seq_len, h, d))
k = self.transpose(k, (0, 2, 1, 3))
v = self.reshape(v, (bs, seq_len, h, d))
v = self.transpose(v, (0, 2, 1, 3))

q = self.cast(q, k.dtype)
attn_scores = self.q_matmul_k(q, k) #bs x h x seq_len x seq_len
attn_scores = self.cast(attn_scores, x.dtype)
attn_scores = self.activation(attn_scores)
else:
q = self.reshape(q, (bs, seq_len, h, d))
q = self.transpose(q, (0, 2, 1, 3))
k = self.reshape(k, (bs, seq_len, h, d))
k = self.transpose(k, (0, 2, 1, 3))
v = self.reshape(v, (bs, seq_len, h, d))
v = self.transpose(v, (0, 2, 1, 3))

attn_scores = self.q_matmul_k(q, k) #bs x h x seq_len x seq_len
attn_scores = self.cast(attn_scores, mstype.float32)
attn_scores = self.mul(attn_scores, self.scale)
attn_scores = self.cast(attn_scores, x.dtype)
attn_scores = self.activation(attn_scores)

out = self.attn_matmul_v(attn_scores, v) #bs x h x seq_len x dim_head
out = self.transpose(out, (0, 2, 1, 3))
out = self.reshape(out, (bs*seq_len, h*d))
out = self.to_out(out)
out = self.reshape(out, (bs, seq_len, d_model))
#out = self.dropout(out)
y = self.cast(out, mstype.float32)
y = self.dropout(y)
out = self.cast(y, out.dtype)
#out = self.reshape(out, (bs, seq_len, d_model))
return out


class FeedForward(Cell):
"""FeedForward layer implementation."""

def __init__(self, vit_config):
super().__init__()

d_model = vit_config.configs.d_model
hidden_dim = vit_config.configs.mlp_dim

initialization = vit_config.feedforward_init
activation = vit_config.feedforward_activation
dropout_rate = vit_config.feedforward_dropout_rate

self.ff1 = BatchDense(d_model, hidden_dim, initialization)
self.activation = activation
self.dropout = Dropout(keep_prob=1.-dropout_rate)
self.ff2 = BatchDense(hidden_dim, d_model, initialization)
self.cast = P.Cast()

def construct(self, x):
y = self.ff1(x)
y = self.cast(y, mstype.float32)
y = self.activation(y)
y = self.dropout(y)
y = self.cast(y, x.dtype)
y = self.ff2(y)
y = self.cast(y, mstype.float32)
y = self.dropout(y)
y = self.cast(y, x.dtype)
return y


class Transformer(Cell):
"""Transformer implementation."""

def __init__(self, vit_config):
super().__init__()

depth = vit_config.configs.depth
drop_path_rate = vit_config.body_drop_path_rate

dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)]
att_seeds = [np.random.randint(1024) for _ in range(depth)]
mlp_seeds = [np.random.randint(1024) for _ in range(depth)]

layers = []
for i in range(depth):
normalization = vit_config.body_norm((vit_config.configs.normalized_shape,))
normalization2 = vit_config.body_norm((vit_config.configs.normalized_shape,))
attention = vit_config.attention(vit_config)
feedforward = vit_config.feedforward(vit_config)

if drop_path_rate > 0:
layers.append(
SequentialCell([
ResidualCell(SequentialCell([normalization,
attention,
DropPath(dpr[i], att_seeds[i])])),
ResidualCell(SequentialCell([normalization2,
feedforward,
DropPath(dpr[i], mlp_seeds[i])]))
])
)
else:
layers.append(
SequentialCell([
ResidualCell(SequentialCell([normalization,
attention])),
ResidualCell(SequentialCell([normalization2,
feedforward]))
])
)

self.layers = SequentialCell(layers)

def construct(self, x):
return self.layers(x)


def load_function(func_name):
"""Load function using its name."""
modules = func_name.split(".")
if len(modules) > 1:
module_path = ".".join(modules[:-1])
name = modules[-1]
module = import_module(module_path)
return getattr(module, name)
return func_name


vit_cfg = edict({
'd_model': 768,
'depth': 12,
'heads': 12,
'mlp_dim': 3072,
'dim_head': 64,
'patch_size': 32,
'normalized_shape': 768,
'image_size': 224,
'num_classes': 1001,
})


def vit_base_patch16(args):
"""vit_base_patch16"""
vit_cfg.d_model = 768
vit_cfg.depth = 12
vit_cfg.heads = 12
vit_cfg.mlp_dim = 3072
vit_cfg.dim_head = vit_cfg.d_model // vit_cfg.heads
vit_cfg.patch_size = 16
vit_cfg.normalized_shape = vit_cfg.d_model
vit_cfg.image_size = args.train_image_size
vit_cfg.num_classes = args.class_num

if args.vit_config_path != '':
print("get vit_config_path")
vit_config = load_function(args.vit_config_path)(vit_cfg)
else:
print("get default_vit_cfg")
vit_config = VitConfig(vit_cfg)

model = vit_config.network(vit_config)
return model


def vit_base_patch32(args):
"""vit_base_patch32"""
vit_cfg.d_model = 768
vit_cfg.depth = 12
vit_cfg.heads = 12
vit_cfg.mlp_dim = 3072
vit_cfg.dim_head = vit_cfg.d_model // vit_cfg.heads
vit_cfg.patch_size = 32
vit_cfg.normalized_shape = vit_cfg.d_model
vit_cfg.image_size = args.train_image_size
vit_cfg.num_classes = args.class_num

if args.vit_config_path != '':
print("get vit_config_path")
vit_config = load_function(args.vit_config_path)(vit_cfg)
else:
print("get default_vit_cfg")
vit_config = VitConfig(vit_cfg)

model = vit_config.network(vit_config)

return model

def get_network(backbone_name, args):
"""get_network"""
if backbone_name == 'vit_base_patch32':
backbone = vit_base_patch32(args=args)
elif backbone_name == 'vit_base_patch16':
backbone = vit_base_patch16(args=args)
else:
raise NotImplementedError
return backbone

Loading…
Cancel
Save