Are you sure you want to delete this task? Once this task is deleted, it cannot be recovered.
xuyang baa4819ac2 | 1 year ago | |
---|---|---|
scripts | 1 year ago | |
src | 1 year ago | |
README.md | 1 year ago | |
export.py | 1 year ago | |
icdar_config.yaml | 1 year ago | |
infer.py | 1 year ago | |
train_dist.py | 1 year ago |
Mask TextSpotter是2018年提出的一种端到端的文本检测模型,该模型具有简单、流畅的训练方案提出的方法,可以检测和识别各种形状的文本,包括水平、定向和弯曲的文本。
[论文] http://xxx.itp.ac.cn/pdf/1807.02242v2 Lyu P , Liao M , Yao C , et al. European Conference on Computer Vision Springer, Cham, 2018.
模型包含四个组件:以基于ResNet-50的特征金字塔网络(FPN)为骨干,用于生成文本候选的区域候选网络(RPN),Fast R-CNN对于边界框回归,用于文本实例分割和字符分割的mask分支
使用的数据集:[ICDAR2013]
数据集大小:252MB
数据格式:JPEG
使用的数据集:[ICDAR2015]
使用的数据集:[scut-eng-char]
使用的数据集:[Total-Text]
使用的数据集:[SynthText]
采用混合精度的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
以FP16算子为例,如果输入数据类型为FP32,MindSpore后台会自动降低精度来处理数据。用户可打开INFO日志,搜索“reduce precision”查看精度降低的算子。
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
Ascend处理器环境运行
# 添加数据集(启智平台添加 https://git.openi.org.cn/OpenModelZoo/masktextspotter/datasets)
# 推理前添加checkpoint路径参数
--chcekpoint_path:'...'
# 运行训练示例
python train_dist.py > train.log 2>&1 &
# 运行分布式训练示例
bash scripts/dist_train.sh
# 运行评估示例
python infer.py --checkpoint_path=[MODEL_PATH] --icdar_root=[DATA_PATH]
对于分布式训练,需要提前创建JSON格式的hccl配置文件。
请遵循以下链接中的说明:
https://gitee.com/mindspore/models/tree/master/utils/hccl_tools.
默认使用5个数据集。您也可以将$dataset_type
传入脚本,以便选择其他数据集。如需查看更多详情,请参考指定脚本。
在 ModelArts 进行训练 (如果你想在modelarts上运行,可以参考以下文档 modelarts)
在 ModelArts 上使用8卡训练5个数据集
# (1) 在网页上设置 "config_path='/path_to_code/config.yaml'"
# (2) 执行a或者b
# a. 在 imagenet_config.yaml 文件中设置 "enable_modelarts=True"
# 在 imagenet_config.yaml 文件中设置 "dataset_name='all'"
# 在 imagenet_config.yaml 文件中设置 "train_data_path='/cache/data/all/train/'"
# 在 imagenet_config.yaml 文件中设置 其他参数
# b. 在网页上设置 "enable_modelarts=True"
# 在网页上设置 "dataset_name=all"
# 在网页上设置 "train_data_path=/cache/data/all/train/"
# 在网页上设置 其他参数
# (3) 上传你的压缩数据集到 S3 桶上 (你也可以上传原始的数据集,但那可能会很慢。)
# (4) 在网页上设置你的代码路径为 "/path/masktextspotter"
# (5) 在网页上设置启动文件为 "train.py"
# (6) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
# (7) 创建训练作业
在 ModelArts 上使用单卡验证 ICDAR2013 数据集
# (1) 在网页上设置 "config_path='/path_to_code/imagenet_config.yaml'"
# (2) 执行a或者b
# a. 在 imagenet_config.yaml 文件中设置 "enable_modelarts=True"
# 在 imagenet_config.yaml 文件中设置 "dataset_name='icdar2013'"
# 在 imagenet_config.yaml 文件中设置 "val_data_path='/cache/data/ICDAR2013/val/'"
# 在 imagenet_config.yaml 文件中设置 "checkpoint_url='s3://dir_to_trained_ckpt/'"
# 在 imagenet_config.yaml 文件中设置 "checkpoint_path='/cache/checkpoint_path/model.ckpt'"
# 在 imagenet_config.yaml 文件中设置 其他参数
# b. 在网页上设置 "enable_modelarts=True"
# 在网页上设置 "dataset_name=icdar2013"
# 在网页上设置 "val_data_path=/cache/data/ICDAR2013/val/"
# 在网页上设置 "checkpoint_url='s3://dir_to_trained_ckpt/'"
# 在网页上设置 "checkpoint_path='/cache/checkpoint_path/model.ckpt'"
# 在网页上设置 其他参数
# (3) 上传你的预训练模型到 S3 桶上
# (4) 上传你的压缩数据集到 S3 桶上 (你也可以上传原始的数据集,但那可能会很慢。)
# (5) 在网页上设置你的代码路径为 "/path/masktextspotter"
# (6) 在网页上设置启动文件为 "eval.py"
# (7) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
# (8) 创建训练作业
在 ModelArts 上使用单卡导出 ICDAR2013 数据集
# (1) 在网页上设置 "config_path='/path_to_code/icdar2013_config.yaml'"
# (2) 执行a或者b
# a. 在 icdar2013_config.yaml 文件中设置 "enable_modelarts=True"
# 在 icdar2013_config.yaml 文件中设置 "checkpoint_url='s3://dir_to_trained_ckpt/'"
# 在 icdar2013_config.yaml 文件中设置 "ckpt_file='/cache/checkpoint_path/model.ckpt'"
# 在 icdar2013_config.yaml 文件中设置 其他参数
# b. 在网页上设置 "enable_modelarts=True"
# 在网页上设置 "checkpoint_url=s3://dir_to_trained_ckpt/" on the website UI interface.
# 在网页上设置 "ckpt_file=/cache/checkpoint_path/model.ckpt" on the website UI interface.
# 在网页上设置 其他参数
# (3) 上传你的预训练模型到 S3 桶上
# (5) 在网页上设置你的代码路径为 "/path/masktextspotter"
# (6) 在网页上设置启动文件为 "export.py"
# (7) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
# (8) 创建训练作业
├── model_zoo
├── README.md // 所有模型相关说明
├── masktextspotter
├── README.md // masktextspotter相关说明
├── scripts
│ ├──run_train.sh // 分布式到Ascend的shell脚本
│ ├──run_train_gpu.sh // 分布式到GPU处理器的shell脚本
│ ├──run_train_cpu.sh // CPU处理器训练的shell脚本
│ ├──run_eval.sh // Ascend评估的shell脚本
│ ├──run_infer_310.sh // Ascend推理shell脚本
│ ├──run_eval_gpu.sh // GPU处理器评估的shell脚本
│ ├──run_eval_cpu.sh // CPU处理器评估的shell脚本
├── src
│ ├──dataset.py // 创建数据集
│ ├──mask_rcnn_r50.py //masktextspotter架构
│ ├──config.py // 参数配置
├── train.py // 训练脚本
├── eval.py // 评估脚本
├── postprogress.py // 310推理后处理脚本
├── export.py // 将checkpoint文件导出到air/mindir
在config.py中可以同时配置训练参数和评估参数。
配置masktextspotter和数据集。
'pre_trained':'False' # 是否基于预训练模型训练
'nump_classes':10 # 数据集类数
'lr_init':0.1 # 初始学习率
'batch_size':128 # 训练批次大小
'epoch_size':125 # 总计训练epoch数
'momentum':0.9 # 动量
'weight_decay':5e-4 # 权重衰减值
'image_height':224 # 输入到模型的图像高度
'image_width':224 # 输入到模型的图像宽度
'data_path':'./all' # 训练和评估数据集的绝对全路径
'device_target':'Ascend' # 运行设备
'device_id':4 # 用于训练或评估数据集的设备ID使用run_train.sh进行分布式训练时可以忽略。
'keep_checkpoint_max':10 # 最多保存checkpoint文件的数量
'checkpoint_path':'./train_masktextspotter_all-125_390.ckpt' # checkpoint文件保存的绝对全路径
'onnx_filename':'masktextspotter.onnx' # export.py中使用的onnx模型文件名
'geir_filename':'masktextspotter.geir' # export.py中使用的geir模型文件名
更多配置细节请参考脚本config.py
。
Ascend处理器环境运行
python train.py > train.log 2>&1 &
上述python命令将在后台运行,您可以通过train.log文件查看结果。
训练结束后,您可在默认脚本文件夹下找到检查点文件。采用以下方式达到损失值:
# grep "loss is " train.log
epoch:1 step:390, loss is 1.4842823
epcoh:2 step:390, loss is 1.0897788
...
模型检查点保存在当前目录下。
GPU处理器环境运行
export CUDA_VISIBLE_DEVICES=0
python train.py > train.log 2>&1 &
上述python命令将在后台运行,您可以通过train.log文件查看结果。
训练结束后,您可在默认./ckpt_0/
脚本文件夹下找到检查点文件。
CPU处理器环境运行
nohup python train.py --config_path=all_config_cpu.yaml --dataset_name=all > train.log 2>&1 &
上述python命令将在后台运行,您可以通过train.log文件查看结果。
训练结束后,您可在yaml文件中配置的文件夹下找到检查点文件。
Ascend处理器环境运行
bash run_distribute_train.sh ~/hccl_8p.json all
上述shell脚本将在后台运行分布训练。您可以通过train_parallel[X]/log文件查看结果。采用以下方式达到损失值:
# grep "result:" train_parallel*/log
train_parallel0/log:epoch:1 step:48, loss is 1.4302931
train_parallel0/log:epcoh:2 step:48, loss is 1.4023874
...
train_parallel1/log:epoch:1 step:48, loss is 1.3458025
train_parallel1/log:epcoh:2 step:48, loss is 1.3729336
...
...
GPU处理器环境运行
bash scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7
上述shell脚本将在后台运行分布训练。您可以通过train/train.log文件查看结果。
在Ascend环境运行时评估ICDAR-2013数据集
在运行以下命令之前,请检查用于评估的检查点路径。请将检查点路径设置为绝对全路径,例如“username/masktextspotter/train_masktextspotter_all-125_390.ckpt”。
python eval.py > eval.log 2>&1 &
OR
bash run_eval.sh icdar2013
上述python命令将在后台运行,您可以通过eval.log文件查看结果。测试数据集的准确性如下:
# grep "accuracy:" eval.log
accuracy:{'acc':0.934}
注:对于分布式训练后评估,请将checkpoint_path设置为最后保存的检查点文件,如“username/masktextspotter/train_parallel0/train_masktextspotter_all-125_48.ckpt”。测试数据集的准确性如下:
# grep "accuracy:" dist.eval.log
accuracy:{'acc':0.9217}
在GPU处理器环境运行时评估ICDAR-2013数据集
在运行以下命令之前,请检查用于评估的检查点路径。请将检查点路径设置为绝对全路径,例如“username/masktextspotter/train/ckpt_0/train_masktextspotter_all-125_390.ckpt”。
python eval.py --checkpoint_path=[CHECKPOINT_PATH] > eval.log 2>&1 &
上述python命令将在后台运行,您可以通过eval.log文件查看结果。测试数据集的准确性如下:
# grep "accuracy:" eval.log
accuracy:{'acc':0.930}
或者,
bash run_eval_gpu.sh [CHECKPOINT_PATH]
上述python命令将在后台运行,您可以通过eval/eval.log文件查看结果。测试数据集的准确性如下:
# grep "accuracy:" eval/eval.log
accuracy:{'acc':0.930}
在导出之前需要修改数据集对应的配置文件,ICDAR-2013的配置文件为icdar2013_config.yaml.
需要修改的配置项为 batch_size 和 ckpt_file.
python export.py --config_path [CONFIG_PATH]
在还行推理之前我们需要先导出模型。Air模型只能在昇腾910环境上导出,mindir可以在任意环境上导出。batch_size只支持1。
在昇腾310上使用ICDAR-2013数据集进行推理
在执行下面的命令之前,我们需要先修改icdar2013的配置文件。修改的项包括batch_size和val_data_path。LABEL_FILE参数只对ImageNet数据集有用,可以传任意值。
推理的结果保存在当前目录下,在acc.log日志文件中可以找到类似以下的结果。
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATASET] [DATA_PATH] [LABEL_FILE] [DEVICE_ID]
after allreduce eval: top1_correct=9252, tot=10000, acc=92.52%
参数 | Ascend | GPU |
---|---|---|
模型版本 | MASKTEXTSPOTTER | MASKTEXTSPOTTER |
资源 | Ascend 910;CPU 2.60GHz,192核;内存 755G;系统 Euler2.8 | NV SMX2 V100-32G |
上传日期 | 2022-09-13 | 2023-09-13 |
MindSpore版本 | 1.5.1 | 1.5.1 |
数据集 | 5个训练集 | 5个训练集 |
训练参数 | epoch=125, steps=390, batch_size = 128, lr=0.02 | epoch=125, steps=390, batch_size=128, lr=0.02 |
优化器 | Momentum | Momentum |
损失函数 | Softmax交叉熵 | Softmax交叉熵 |
输出 | 概率 | 概率 |
损失 | 0.0016 | 0.0016 |
速度 | 单卡:554毫秒/步; 8卡:350毫秒/步 | 单卡:150毫秒/步; 8卡:164毫秒/步 |
总时长 | 单卡:63.85分钟; 8卡:11.28分钟 | 单卡:126.87分钟; 8卡:21.65分钟 |
参数(M) | 13.0 | 13.0 |
微调检查点 | 43.07M (.ckpt文件) | 43.07M (.ckpt文件) |
推理模型 | 21.50M (.onnx文件), 21.60M(.air文件) | |
脚本 | masktextspotter脚本 | masktextspotter脚本 |
参数 | Ascend | GPU |
---|---|---|
模型版本 | MASKTEXTSPOTTER | MASKTEXTSPOTTER |
资源 | Ascend 910;系统 Euler2.8 | GPU |
上传日期 | 2022-09-13 | 2022-09-13 |
MindSpore 版本 | 1.5.1 | 1.5.1 |
数据集 | ICDAR-2013, 233张图像 | ICDAR-2013, 233张图像 |
batch_size | 1 | 1 |
输出 | 概率 | 概率 |
准确性 | 单卡: 69.4%; 8卡:69.7% | 单卡:95.1%, 8卡:95.8% |
推理模型 | 21.50M (.onnx文件) |
如果您需要使用此训练模型在GPU、Ascend 910、Ascend 310等多个硬件平台上进行推理,可参考此链接。下面是操作步骤示例:
Ascend处理器环境运行
# 设置上下文
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend', device_id=1)
# 加载未知数据集进行推理
ds = create_maskrcnn_dataset(mindrecord_file, batch_size=config.test_batch_size, is_training=False)
dataset_size = ds.get_dataset_size()
# 定义模型
net = MaskTextSpotter_Resnet50(config=config)
# 加载预训练模型
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# 对未知数据集进行预测
acc = model.eval(dataset)
print("accuracy:", acc)
GPU处理器环境运行
# 设置上下文
context.set_context(mode=context.GRAPH_HOME, device_target="GPU")
# 加载未知数据集进行推理
dataset = dataset.create_dataset(cfg.data_path, 1, False)
# 定义模型
net = masktextspotter(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean',
is_grad=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
# 加载预训练模型
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy:", acc)
Ascend处理器环境运行
# 加载数据集
ds = create_maskrcnn_dataset(mindrecord_file, batch_size=config.batch_size, is_training=True)
dataset_size = ds.get_dataset_size()
# 定义模型
net = MaskTextSpotter_Resnet50(config=config)
net.set_train(True)
# 设置回调
loss = LossNet()
net = WithLossCell(net, loss)
net = nn.TrainOneStepCell(net, optim_sgd)
model = Model(net)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
# 开始训练
model.train(config.epoch_size, ds, callbacks=cb, sink_size=dataset_size, dataset_sink_mode=False)
print("train success")
GPU处理器环境运行
# 加载数据集
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()
# 定义模型
net = masktextspotter(num_classes=cfg.num_classes)
# 若pre_trained为True,继续训练
if cfg.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
# 设置回调
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_masktextspotter_icdar2013", directory="./ckpt_" + str(get_rank()) + "/",
config=config_ck)
loss_cb = LossMonitor()
# 开始训练
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")
在dataset.py中,我们设置了“create_dataset”函数内的种子,同时还使用了train.py中的随机种子。
请浏览官网主页。
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》