Are you sure you want to delete this task? Once this task is deleted, it cannot be recovered.
15534081591 841622628d | 1 year ago | |
---|---|---|
ckpt | 1 year ago | |
pictures | 1 year ago | |
scripts | 1 year ago | |
src | 1 year ago | |
README.md | 1 year ago | |
export.py | 1 year ago | |
spade_eval.py | 1 year ago | |
spade_train.py | 1 year ago |
SPADE是2019年提出的语义图像合成算法,该论文发表在CVPR2019上面,该算法提出了空间自适应归一化,用于在给定输入语义布局的情况下合成照片级真实感图像.
论文:Taesung Park, Ming-Yu Liu, Ting-Chun Wang, and Jun-Yan Zhu. "Semantic Image Synthesis with Spatially-Adaptive Normalization.g". Presented at CVPR 2019.
SPADE模型采用生成式对抗网络作为网络主干。其中,生成器由一系列带有最近邻上采样的SPADE ResBlks组成,判别器的体系结构遵循pix2pixHD方法中使用的体系结构,该方法使用多尺度设计的InstanceNorm。唯一的区别是SPADE将谱归一化应用于所有的卷积层.
使用的数据集:[ADE20K]
官网链接
ADE20K http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
训练开始前需获取vgg的预训练模型(pth) 官方下载链接: https://download.pytorch.org/models/vgg19-dcbb9e9d.pth
将获取到的vgg19-dcbb9e9d.pth转换成vgg.ckpt
# 转换vgg.pth脚本,会在运行目录下生成一个vgg.ckpt
python ./src/utils/pth2ckpt.py [pth_path] 'vgg'
#Ascend多卡训练
bash run_distribute_train.sh [DEVICE_NUM] [VGG_CKPT_PATH] [DATAROOT] [RANK_TABLE_PATH]
"""
[DEVICE_NUM]:卡的数量
[VGG_CKPT_PATH]:vgg预训练模型的绝对路径
[DATAROOT]:ADE20K数据集的绝对路径
[RANK_TABLE_PATH]: rank_table文件绝对路径
"""
#Ascend单卡训练
bash run_standalone_train.sh [DEVICE_ID] [VGG_CKPT_PATH] [DATA_PATH]
"""
[DEVICE_ID]:卡的编号
[VGG_CKPT_PATH]:vgg预训练模型的绝对路径
[DATA_PATH]:ADE20K数据集的绝对路径
"""
#Ascend单卡测试
bash run_eval.sh [DEVICE_ID] [CKPT_PATH] [DATA_PATH] [RESULT_PATH] [INCEPTION_CKPT_PATH]
"""
[DEVICE_ID]:卡的编号
[CKPT_PATH]:训练好的ckpt绝对路径
[DATA_PATH]:ADE20K数据集的绝对路径
[RESULT_PATH]:推理结果路径
[INCEPTION_CKPT_PATH]:inception网络预训练模型ckpt绝对路径
"""
Ascend训练:生成RANK_TABLE_FILE
├── SPADE
├── scripts
│ ├──run_distribute_train.sh // 在Ascend中多卡训练
│ ├──run_distribute_test.sh // 在Ascend中单卡测试
│ ├──run_standalone_train.sh // 在Ascend中单卡训练
├── src //源码
│ │ ├── data
│ │ │ ├──ade20k_dataset.py
│ │ │ ├──image_folder.py
│ │ ├── models
│ │ │ ├──netG.py //生成器网络结构
│ │ │ ├──netD.py //判别器网络结构
│ │ │ ├──normalization.py //自定义正则化
│ │ │ ├──architecture.py
│ │ │ ├──cells.py //loss网络wrapper
│ │ │ ├──inception.py //FID推理网络结构
│ │ │ ├──init_Parameter.py //参数初始化
│ │ │ ├──loss.py //损失函数定义
│ │ │ ├──vgg.py //损失函数网络结构
│ │ ├── utils
│ │ │ ├──adam.py //自定义优化器
│ │ │ ├──lr_schedule.py //自定义学习率策略
│ │ │ ├──update_weight.py //谱归一化实现
│ │ │ ├──util.py //图片处理工具
│ │ │ ├──util.py //图片处理工具
│ │ │ ├──eval_fid.py //精度计算工具
│ │ │ ├──visualizer.py //可视化结果
├── README.md // SPADE相关说明
├── spade_train.py // 训练入口
├── spade_eval.py // 推理入口
├── export.py // 模型导出
spade_train.py
--data_url: obs桶数据集位置
--train_url: 输出文件路径
--batchSize: 输入的batch大小
--dataroot:数据集根目录
--vgg_ckpt_path:vgg ckpt路径
--decay_epoch: 学习率变化的起始epoch
--total_epoch: 学习率变化的最终epoch
--G_lr:生成器起始学习率
--D_lr:判别器起始学习率
--id: 使用的物理卡号
--distribute: 多卡运行
--run_modelarts: ModelArts上运行,默认为False
spade_eval.py
--ckpt_dir: 权重文件路径
--results_dir:运行结果保存的路径
--id: 使用的物理卡号
--fid_eval_ckpt_dir: inception网络预训练参数
Ascend处理器环境运行
python spade_train.py --id device_id --vgg_ckpt_path ./vgg.ckpt --dataroot ./ADEChallengeData2016 ./rank_table_8pcs.json
# 或进入脚本目录,执行脚本,请使用绝对路径
bash run_distribute_train.sh [DEVICE_NUM] [VGG_CKPT_PATH] [DATAROOT] [RANK_TABLE_PATH]
# 请使用绝对路径
经过训练后,损失值如下:
[199/200][1262/1263]: Loss_D: 0.245605 Loss_G: 26.058798
[199/200][1263/1263]: Loss_D: 0.571079 Loss_G: 18.928753
[199/200][1263/1263]: Loss_D: 0.383772 Loss_G: 24.714558
[199/200][1263/1263]: Loss_D: 0.470536 Loss_G: 22.845257
[199/200][1263/1263]: Loss_D: 0.284949 Loss_G: 25.437853
[199/200][1263/1263]: Loss_D: 0.240993 Loss_G: 25.671362
[199/200][1263/1263]: Loss_D: 0.550708 Loss_G: 21.344685
[199/200][1263/1263]: Loss_D: 0.132908 Loss_G: 29.253286
[199/200][1263/1263]: Loss_D: 0.093034 Loss_G: 27.286970
评估开始前需获取inception的预训练模型(pth) 官方下载链接: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
将获取到的pt_inception-2015-12-05-6726825d.pth转换成inception_pid.ckpt
Ascend处理器环境运行推理
# 运行`spade.run.py`生成inception_pid.ckpt。
python ./src/utils/pth2ckpt.py [pth_path] 'inception'
# 进入脚本目录,根据ADE20K数据集images和annotaions文件夹下的validation文件夹生成预测文件并得到fid推理精度。
python spade_eval.py --id 0 --dataroot ./ADEChallengeData2016 --results_dir ./results --ckpt_dir ./netG_epoch_200.ckpt --fid_eval_ckpt_dir ./inception_pid.ckpt
# 或进入脚本目录,执行脚本
bash run_eval.sh [DEVICE_ID] [CKPT_PATH] [DATA_PATH] [RESULT_PATH] [INCEPTION_CKPT_PATH]
测试数据集的准确率如下:
实际测试的FID精度为39.90053572672258
python export.py --ckpt_dir [CKPT_PATH] --file_format [FILE_FORMAT]
file_format
必须在 ["AIR", "MINDIR"]中选择。
ckpt_dir
ckpt存放路径
脚本会在当前目录下生成对应的AIR文件。
参数 | ModelArts |
---|---|
资源 | Ascend 910;CPU 2.60GHz, 192核;内存:755G |
上传日期 | 2021-08-12 |
MindSpore版本 | 1.3.0 |
数据集 | ADE20k |
训练参数 | epoch=200, batch_size=2, D_lr=0.0004,G_lr=0.0002 |
损失函数 | L1Loss,vgg |
损失 | g_loss:14左右,d_loss:0.2左右 |
速度 | 450毫秒/步 |
总时间 | 35小时 |
微调检查点 | 368.62 MB (.ckpt文件) |
无
请浏览官网主页
No Description
Python Shell
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》