Author: 邢朝龙 (kaierlong@126.com)
Swin Transformer V2 is an upgraded network architecture proposed by the Microsoft team on top of Swin Transformer (V1). Existing large vision models face several problems: training instability as model capacity grows, a resolution gap between pre-training and fine-tuning, and a heavy demand for labeled data. To address these problems, Swin Transformer V2 introduces three main improvements: a residual-post-norm layout combined with scaled cosine attention to stabilize training; a log-spaced continuous relative position bias so that models pre-trained with small windows transfer to larger windows and resolutions; and SimMIM self-supervised pre-training to reduce the need for labeled images.
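To make the first two improvements concrete, here is a minimal sketch (not the repository's implementation; see `src/models/swintransformer/swin_transformer_v2.py` for the real one) of a scaled cosine attention logit and the log-spaced relative coordinate. The `tau` scalar is learnable per head in the paper but is a constant here, and the log base follows one common convention:

```python
import math

def cosine_attention_logit(q, k, tau=0.1, bias=0.0):
    """Pre-softmax logit with scaled cosine attention:
    Sim(q, k) = cos(q, k) / tau + B, where tau is a (learnable)
    temperature and B the relative position bias term."""
    dot = sum(a * b for a, b in zip(q, k))
    norm = math.sqrt(sum(a * a for a in q)) * math.sqrt(sum(b * b for b in k))
    return dot / norm / tau + bias

def log_spaced_offset(delta):
    """Log-spaced relative coordinate for the continuous position bias:
    sign(d) * log2(1 + |d|) keeps offsets bounded as the window grows."""
    return math.copysign(math.log2(1 + abs(delta)), delta)

print(cosine_attention_logit([1.0, 2.0], [1.0, 2.0]))  # cos = 1 -> 1/tau = 10.0
print(log_spaced_offset(-7))                           # -3.0
```

Because the logits are cosine similarities divided by a temperature, their magnitude no longer depends on the scale of the activations, which is what tames training at large model sizes.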
The training and test datasets are as follows:

- Dataset: ImageNet2012
- Dataset size: 1000 classes of 224*224 color images
- Training set: 1,281,167 images
- Test set: 50,000 images
- Data format: JPEG

Note: data loading and preprocessing are handled in src/data/imagenet.py.
Download the dataset; the directory structure should look like this:

└─dataset
    ├─train  # training set
    └─val    # evaluation set
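Before launching a multi-day training run, it is worth sanity-checking that both splits are laid out as one class sub-directory per class. A small stdlib-only helper (hypothetical, not part of this repository) demonstrated on a toy three-class tree:

```python
import os
import tempfile

def count_classes(root):
    """Count class sub-directories under root/train and root/val.
    For ImageNet2012 both counts should be 1000."""
    counts = {}
    for split in ("train", "val"):
        split_dir = os.path.join(root, split)
        counts[split] = sum(
            os.path.isdir(os.path.join(split_dir, d))
            for d in os.listdir(split_dir)
        ) if os.path.isdir(split_dir) else 0
    return counts

# Toy demo with three classes instead of the full 1000
with tempfile.TemporaryDirectory() as root:
    for split in ("train", "val"):
        for cls in ("n01440764", "n01443537", "n01484850"):
            os.makedirs(os.path.join(root, split, cls))
    result = count_classes(root)

print(result)  # {'train': 3, 'val': 3}
```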
The repository is available at: https://github.com/kaierlong/Swin-Transformer-V2-Ascend

The code directory structure is as follows:
.
├── README.md // documentation (English)
├── README_CN.md // documentation (Chinese)
├── conv_pth2ckpt.py // converts pre-trained PyTorch weights to a MindSpore checkpoint
├── eval.py // evaluation entry point
├── image // images used in the documentation
├── src
│   ├── args.py
│   ├── configs // model configuration files
│   │   ├── parser.py
│   │   ├── swin_tiny_patch4_window7_224.yaml
│   │   ├── swinv2_base_patch4_window12to16_192to256_22kto1k_ft.ckpt // pre-trained weights
│   │   ├── swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml
│   │   ├── swinv2_base_patch4_window8_256.yaml
│   │   ├── swinv2_large_patch4_window16_256.yaml
│   │   ├── swinv2_small_patch4_window8_256.yaml
│   │   └── swinv2_tiny_patch4_window8_256.yaml
│   ├── data // data loading and preprocessing
│   │   ├── __init__.py
│   │   ├── augment
│   │   │   ├── __init__.py
│   │   │   ├── auto_augment.py
│   │   │   ├── custom_transforms.py
│   │   │   ├── mixup.py
│   │   │   └── random_erasing.py
│   │   ├── data_utils
│   │   │   ├── __init__.py
│   │   │   └── moxing_adapter.py
│   │   └── imagenet.py
│   ├── image22kto1k.txt // 22K-to-1K class ID mapping
│   ├── models // model definitions
│   │   ├── __init__.py
│   │   └── swintransformer
│   │       ├── __init__.py
│   │       ├── clip_ops.py
│   │       ├── get_swin.py
│   │       ├── get_swin_v2.py
│   │       ├── misc.py
│   │       ├── swin_transformer.py
│   │       └── swin_transformer_v2.py // Swin Transformer V2 definition
│   ├── tools // utilities
│   │   ├── __init__.py
│   │   ├── callback.py
│   │   ├── cell.py
│   │   ├── criterion.py
│   │   ├── get_misc.py
│   │   ├── optimizer.py
│   │   └── schedulers.py
│   └── trainers // training step wrappers
│       ├── __init__.py
│       ├── model_ema.py
│       ├── train_one_step_with_ema.py
│       └── train_one_step_with_scale_and_clip_global_norm.py
└── train.py // training entry point
The software and hardware environment requirements are as follows:

The hyperparameter configuration is as follows:
# Architecture
arch: swinv2_base_patch4_window12to16_192to256_22kto1k_ft
# ===== Dataset ===== #
data_url: ./data/imagenet
set: ImageNet
num_classes: 1000
mix_up: 0.8
cutmix: 1.0
auto_augment: rand-m9-mstd0.5-inc1
interpolation: bicubic
re_prob: 0.25
re_mode: pixel
re_count: 1
mixup_prob: 1.
switch_prob: 0.5
mixup_mode: batch
crop_ratio: 0.875
# ===== Learning Rate Policy ======== #
optimizer: adamw
lr_scheduler: cosine_lr
base_lr: 0.00005
min_lr: 0.0000002
warmup_length: 5
warmup_lr: 0.00000002
cool_length: 10
cool_lr: 0.0000002
nonlinearity: GELU
# ===== Network training config ===== #
amp_level: O1
keep_bn_fp32: True
beta: [ 0.9, 0.999 ]
is_dynamic_loss_scale: True
use_global_norm: True
clip_global_norm_value: 5.
enable_ema: False
ema_decay: 0.9999
loss_scale: 1024
weight_decay: 0.00000001
momentum: 0.9
label_smoothing: 0.1
epochs: 40
batch_size: 32
# ===== Hardware setup ===== #
num_parallel_workers: 32
device_target: Ascend
# ===== Model config ===== #
drop_path_rate: 0.2
embed_dim: 128
depths: [ 2, 2, 18, 2 ]
num_heads: [ 4, 8, 16, 32 ]
window_size: 16
image_size: 256
pretrained_window_sizes: [ 12, 12, 12, 6 ]
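The learning-rate keys above describe a three-phase schedule: linear warmup over `warmup_length` epochs, cosine decay down to `min_lr`, then a constant `cool_lr` for the final `cool_length` epochs. The actual implementation lives in `src/tools/schedulers.py` and may differ in detail; the sketch below only illustrates how these config values plausibly fit together:

```python
import math

# Values taken from the YAML above
base_lr, min_lr = 5e-5, 2e-7
warmup_lr, warmup_length = 2e-8, 5
cool_lr, cool_length = 2e-7, 10
epochs = 40

def lr_at(epoch):
    """Per-epoch LR: linear warmup, cosine decay, constant cool-down."""
    if epoch < warmup_length:
        # linear ramp from warmup_lr up to base_lr
        t = epoch / warmup_length
        return warmup_lr + t * (base_lr - warmup_lr)
    if epoch >= epochs - cool_length:
        # hold a small constant LR for the final epochs
        return cool_lr
    # cosine decay from base_lr to min_lr over the middle epochs
    decay_epochs = epochs - warmup_length - cool_length
    t = (epoch - warmup_length) / decay_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))

schedule = [lr_at(e) for e in range(epochs)]
```

Plotting `schedule` is a quick way to confirm the warmup, decay, and cool-down phases line up with the epoch budget before launching a run.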
Note: because a pre-trained model is required, the PyTorch weights must be converted first. Download the PyTorch model in advance, then run:

# Note: a PyTorch environment is required for the conversion
python3 conv_pth2ckpt.py --pth_file=swinv2_base_patch4_window12_192_22k.pth --ckpt_file=src/configs/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.ckpt --cls_map_file=src/image22kto1k.txt
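My reading of the `--cls_map_file` option is that `src/image22kto1k.txt` lists, for each of the 1000 ImageNet-1K classes, its index in the 22K-class head, so the conversion can keep only those classifier rows (the real `conv_pth2ckpt.py` also renames parameters to MindSpore conventions, which is omitted here). A toy sketch of that row selection:

```python
def remap_classifier_rows(weight_rows, keep_indices):
    """Keep only the rows of a many-class head that correspond to the
    target classes, in mapping-file order (simplified illustration)."""
    return [weight_rows[i] for i in keep_indices]

head = [[float(i)] for i in range(6)]    # toy 6-class head instead of 22K
idx = [0, 2, 5]                          # toy stand-in for image22kto1k.txt
print(remap_classifier_rows(head, idx))  # [[0.0], [2.0], [5.0]]
```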
The training command is:
python3 train.py --run_openi=True --arch=swinv2_base_patch4_window12to16_192to256_22kto1k_ft --pretrained=swin --device_num=8
The inference command is:
python3 eval.py --config=src/configs/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml --pretrained={ckpt_path} --device_id={device_id} --device_target={device_target} --data_url={data_url}