|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- import mindspore as ms
- import os
- import wget
-
# Names accepted by ``create_model``; ``list_models()`` prints this list.
model_list = [
    'conformer_overflow_dw',
    'conformer_overflow',
    'conformer',
    'mobilenetv3_large_100',
    'resnet50',
    'vit_base_patch32_224',
    'swin_tiny_patch4_window7_224',
    'swinv2_tiny_patch4_window8_256',
    'beit_base_patch16_224',
    'beit_base_patch16_384',
    'beit_large_patch16_224',
    'beit_large_patch16_384',
    'beit_large_patch16_512',
    'beitv2_base_patch16_224',
    'beitv2_large_patch16_224',
    'xcit_large_24_p8_224',
    'deit3_small_patch16_384',
    'deit3_huge_patch14_224',
]
-
def list_models():
    """Print a header line followed by every model name ``create_model`` accepts."""
    # sep='\n' reproduces the original two print() calls: header, blank line, list repr.
    print('The currently supported models are as follows:\n', model_list, sep='\n')
-
def create_model(
        model_name,
        pretrained=False,
        **kwargs):
    """Create a model by name, optionally loading pretrained weights.

    Args:
        model_name (str): name of the model to instantiate; one of the names
            printed by ``list_models()``.
        pretrained (bool): if true, load weights from the model's checkpoint
            under ``./pretrained_model``. A missing checkpoint is downloaded
            first when a download URL is known. Download/load failures are
            reported with a printed message and the network is returned as
            constructed (best effort).
        **kwargs: keyword arguments for the model constructor. Entries whose
            value is ``None`` are dropped. Only the conformer*, mobilenetv3,
            resnet50 and vit models receive them; the other models ignore them.

    Returns:
        The constructed network.

    Raises:
        ValueError: if ``model_name`` is not a supported model.
    """
    # Drop None-valued options so they are not passed to the constructors at all.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    net, checkpoint_path, url_from_google_driver = _build_model(model_name, kwargs)
    if pretrained:
        _load_pretrained(net, checkpoint_path, url_from_google_driver)
    return net


def _build_model(model_name, kwargs):
    """Instantiate ``model_name``; return ``(net, checkpoint_path, download_url)``.

    Model modules are imported inside each branch, so only the requested
    model's module is imported. ``download_url`` is '' when no URL is known.
    Raises ValueError for an unsupported name.
    """
    if model_name == 'conformer_overflow_dw':
        from models.conformer_overflow_dw import conformer_overflow_dw as ConformerOverflowDW
        return (ConformerOverflowDW(**kwargs),
                './pretrained_model/conformer_overflow_dw.ckpt',
                '')

    if model_name == 'conformer_overflow':
        from models.conformer_overflow import ConformerOverflow
        # Previously pointed at conformer_overflow_dw.ckpt (copy-paste from the branch above).
        return (ConformerOverflow(**kwargs),
                './pretrained_model/conformer_overflow.ckpt',
                '')

    if model_name == 'conformer':
        from models.conformer import Conformer
        return (Conformer(**kwargs),
                './pretrained_model/conformer.ckpt',
                '')

    if model_name == 'mobilenetv3_large_100':
        from models.mobilenetV3 import mobilenet_v3_large as mobilenetv3
        return (mobilenetv3(**kwargs),
                './pretrained_model/mobilenetv3.ckpt',
                'https://drive.google.com/u/0/uc?id=1DEuPh6FMtuifxfn0Pj_I9Y-_UjMIyI2M&export=download')

    if model_name == 'resnet50':
        from models.resnet import resnet50 as resnet
        return (resnet(**kwargs),
                './pretrained_model/resnet50.ckpt',
                '')

    if model_name == 'vit_base_patch32_224':
        from models.vit import get_network
        return (get_network(**kwargs),
                './pretrained_model/vit_base_patch32_224.ckpt',
                '')

    # The Swin models are configured from the project's ``core.args_*`` modules;
    # kwargs are not forwarded to them.
    if model_name == 'swin_tiny_patch4_window7_224':
        from models.swin_transformer import get_swinv1, CreateSwinv1
        from core.args_swinv1 import args
        return (CreateSwinv1(get_swinv1(args)),
                './pretrained_model/swin_tiny_patch4_window7_224.ckpt',
                'https://drive.google.com/u/0/uc?id=1RSo02ptqnfr7p_1cfmJA8etfUtmZ-Yxl&export=download')

    if model_name == 'swinv2_tiny_patch4_window8_256':
        from models.swin_transformer_v2 import get_swinv2, CreateSwinv2
        from core.args_swinv2 import args
        return (CreateSwinv2(get_swinv2(args)),
                './pretrained_model/swinv2_tiny_patch4_window8_256.ckpt',
                'https://drive.google.com/u/0/uc?id=1kjf2cW6g0_n2SjaC3halY4kFgSYI3L__&export=download')

    # The BEiT, XCiT and DeiT3 builders below are called without kwargs.
    if model_name == 'beit_base_patch16_224':
        from models.beit import beit_base_patch16_224
        return (beit_base_patch16_224(),
                './pretrained_model/beit_base_patch16_224.ckpt',
                'https://drive.google.com/u/0/uc?id=1-Cr1KQD7AR6I8TKVeHvOlhgBgvTR-0Us&export=download')

    if model_name == 'beit_base_patch16_384':
        from models.beit import beit_base_patch16_384
        return (beit_base_patch16_384(),
                './pretrained_model/beit_base_patch16_384.ckpt',
                'https://drive.google.com/u/0/uc?id=1-Kqh28JaUBlDO5mYVt97a0nVcooVghWy&export=download')

    if model_name == 'beit_large_patch16_224':
        from models.beit import beit_large_patch16_224
        return (beit_large_patch16_224(),
                './pretrained_model/beit_large_patch16_224.ckpt',
                'https://drive.google.com/u/0/uc?id=1-Q0Y86yPORn_xnnbbojjRqxp_RFPyz1n&export=download')

    if model_name == 'beit_large_patch16_384':
        from models.beit import beit_large_patch16_384
        # NOTE(review): this Drive ID is identical to beit_base_patch16_384's;
        # one of the two is almost certainly wrong — confirm the correct ID.
        return (beit_large_patch16_384(),
                './pretrained_model/beit_large_patch16_384.ckpt',
                'https://drive.google.com/u/0/uc?id=1-Kqh28JaUBlDO5mYVt97a0nVcooVghWy&export=download')

    if model_name == 'beit_large_patch16_512':
        from models.beit import beit_large_patch16_512
        return (beit_large_patch16_512(),
                './pretrained_model/beit_large_patch16_512.ckpt',
                'https://drive.google.com/u/0/uc?id=1-A9XNXNo0vZsAGKUGPErAcvSV1cRZrtB&export=download')

    if model_name == 'beitv2_base_patch16_224':
        from models.beit import beitv2_base_patch16_224
        return (beitv2_base_patch16_224(),
                './pretrained_model/beitv2_base_patch16_224.ckpt',
                'https://drive.google.com/u/0/uc?id=1-SP4OdbFK9bBYg5fbX8sYOizhuMwnPCY&export=download')

    if model_name == 'beitv2_large_patch16_224':
        from models.beit import beitv2_large_patch16_224
        return (beitv2_large_patch16_224(),
                './pretrained_model/beitv2_large_patch16_224.ckpt',
                'https://drive.google.com/u/0/uc?id=1-ZTQvQJ5TiJXbpyMEFR-0vLC-qLQKinM&export=download')

    if model_name == 'xcit_large_24_p8_224':
        from models.xcit import get_xcit_large_24_p8_224
        return (get_xcit_large_24_p8_224(),
                './pretrained_model/xcit_large_24_p8_224.ckpt',
                'https://drive.google.com/u/0/uc?id=1-28WfRpLcDrhUoxUm9YNYeri3HaqcMge&export=download')

    if model_name == 'deit3_small_patch16_384':
        from models.deit3 import get_deit3_small_patch16_384
        return (get_deit3_small_patch16_384(),
                './pretrained_model/deit3_small_patch16_384_1k.ckpt',
                'https://drive.google.com/u/0/uc?id=1-2qJI-QV9kKDNFYILgHiTXRMf3SG8cTF&export=download')

    if model_name == 'deit3_huge_patch14_224':
        from models.deit3 import get_deit3_huge_patch14_224
        return (get_deit3_huge_patch14_224(),
                './pretrained_model/deit3_huge_patch14_224.ckpt',
                'https://drive.google.com/u/0/uc?id=1-9LmuBzU1RZDvwmq-vkFKbG5aZpP6YkZ&export=download')

    # Previously this only printed and then crashed on an unbound ``net``/``checkpoint_path``.
    raise ValueError(
        f"unknown model name {model_name!r}: check model's name and retry "
        "(call list_models() to see the supported models)")


def _download_checkpoint(url, checkpoint_path):
    """Download ``url`` to ``checkpoint_path``; return True on success, False otherwise."""
    if not url:
        print(f'no download URL is known for {checkpoint_path}')
        return False
    print('downloading from google driver')
    try:
        os.makedirs(os.path.dirname(checkpoint_path) or '.', exist_ok=True)
        wget.download(url, checkpoint_path)
    except Exception as err:  # best effort: report network/IO failures instead of raising
        print(f'download error: {err}')
        return False
    print('download success')
    return True


def _load_pretrained(net, checkpoint_path, url):
    """Best effort: fetch ``checkpoint_path`` if missing, then load its parameters into ``net``."""
    if not os.path.exists(checkpoint_path):
        _download_checkpoint(url, checkpoint_path)
    try:
        param_dict = ms.load_checkpoint(checkpoint_path)
        ms.load_param_into_net(net, param_dict)
    except Exception as err:  # keep returning the constructed net, as before
        if url:
            print(f'please download model from {url} and store files in {checkpoint_path}')
        else:
            print(f'please store the pretrained checkpoint in {checkpoint_path}')
        print(f'(loading failed: {err})')
|