Are you sure you want to delete this task? Once this task is deleted, it cannot be recovered.
youlz 3172f2d980 | 1 year ago | |
---|---|---|
model_utils | 1 year ago | |
scripts | 1 year ago | |
src | 1 year ago | |
README.md | 1 year ago | |
__init__.py | 1 year ago | |
converter.py | 1 year ago | |
default_config.yaml | 1 year ago | |
eval.py | 1 year ago | |
export.py | 1 year ago | |
requirements.txt | 1 year ago | |
train.py | 1 year ago | |
vgg16_feat_extr_ms.ckpt | 1 year ago |
深度生成方法最近通过引入结构先验在图像修复方面取得了长足的进步。然而,由于在结构重建过程中缺乏与图像纹理的适当交互,目前的解决方案在处理大腐败的情况时能力不足,并且通常会导致结果失真。CTSDG 是一种新颖的用于图像修复的双流网络,它以耦合的方式对结构约束的纹理合成和纹理边缘引导结构重建进行建模,使它们更好地相互利用,以获得更合理的生成。此外,为了增强全局一致性,设计了双向门控特征融合( Bi-GFF )模块来交换和结合结构和纹理信息,并开发了上下文特征聚合( CFA )模块,通过区域亲和学习和多尺度特征聚合来细化生成的内容。
用到的数据集:
需要从 CELEBA 下载以下内容:
img_align_celeba.zip
list_eval_partitions.txt
需要从 NVIDIA Irregular Mask Dataset 下载以下内容:
irregular_mask.zip
test_mask.zip
目录结构如下:
.
├── img_align_celeba # 图像文件夹
├── irregular_mask # 用于训练的遮罩
│ └── disocclusion_img_mask
├── mask # 用于测试的遮罩
│ └── testing_mask_dataset
└── list_eval_partition.txt # 拆分文件
https://git.openi.org.cn/youlz/CTSDG
日志文件保存在 log.zip 中
.
├── converter.py # 将 VGG16 转换为 mindspore 的 checkpoint
├── dataset
│ ├── img_align_celeba # celeba 图像文件夹
│ ├── irregular_mask # 用于训练的遮罩
│ ├── list_eval_partition.txt # 拆分文件
│ └── mask # 用于测试的遮罩
├── default_config.yaml # 默认配置文件
├── eval.py # 评估 mindspore 模型
├── __init__.py # 初始化文件
├── model_utils
│ ├── config.py # 语法参数
│ └── __init__.py # 初始化文件
├── requirements.txt
├── scripts
│ ├── run_eval_npu.sh # 在 NPU 上启动评估的脚本
│ └── run_train_npu.sh # 在 NPU 上启动训练的脚本
├── src
│ ├── callbacks.py # 回调
│ ├── dataset.py # celeba 数据集
│ ├── discriminator # 鉴别器
│ ├── generator # 生成器
│ ├── initializer.py # 初始化器权重
│ ├── __init__.py # 初始化文件
│ ├── losses.py # 模型 loss
│ ├── trainer.py # ctsdg模型的训练者
│ └── utils.py # 工具
├── train.py # 训练 mindspore 模型
└── vgg16-397923af.pth # VGG16 torch 模型
可以在 default_config.yaml
中配置训练参数
"gen_lr_train": 0.0002, # 生成器训练的 lr
"gen_lr_finetune": 0.00005, # 生成器微调的 lr
"dis_lr_multiplier": 0.1, # 判别器的 lr 是生成器的 lr 乘以这个参数
"batch_size": 6, # batch size
"train_iter": 350000, # 训练迭代次数
"finetune_iter": 150000 # 微调迭代次数
"image_load_size": [256, 256] # 输入图像大小
有关更多参数,请参见 default_config.yaml
的内容。
train_iter : 350000
finetune_iter : 150000
gen_lr_train : 0.0002
gen_lr_finetune : 0.00005
dis_lr_multiplier : 0.1
batch_size : 6
Loss function : GWithLossCell() , DWithLossCell()
Optimizer : Adam
对于训练 CTSDG 模型,需要对 VGG16 torch 模型进行感知损失转换。
python converter.py --torch_pretrained_vgg=/path/to/torch_pretrained_vgg
转换后的 mindpore checkpoint 将保存在与 torch 模型相同的目录中,名称为vgg16_feat_extr_ms.ckpt
。
After preparing the dataset and converting VGG16 you can start training and evaluation as follows:
准备好数据集同时完成 VGG16 的转换后,就可以通过如下步骤开始训练和评估模型了。
# train
bash scripts/run_train_npu.sh [DEVICE_ID] [CFG_PATH] [SAVE_PATH] [VGG_PRETRAIN] [IMAGES_PATH] [MASKS_PATH] [ANNO_PATH]
Example:
# DEVICE_ID - 用于训练的设备 ID 号
# CFG_PATH - config 的路径
# SAVE_PATH - 保留 logs and checkpoints 的路径
# VGG_PRETRAIN - 预训练 VGG16 的路径
# IMAGES_PATH - CELEBA 数据集的路径
# MASKS_PATH - 用于训练的遮罩路径
# ANNO_PATH - 拆分文件的路径
bash scripts/run_train_npu.sh 0 ./default_config.yaml /path/to/output /path/to/vgg16_feat_extr.ckpt /path/to/img_align_celeba /path/to/training_mask /path/to/list_eval_partitions.txt
# evaluate
bash scripts/run_eval_npu.sh [DEVICE_ID] [CFG_PATH] [CKPT_PATH] [IMAGES_PATH] [MASKS_PATH] [ANNO_PATH]
Example:
# evaluate
# DEVICE_ID - 用于评估的设备 ID 号
# CFG_PATH - config 的路径
# CKPT_PATH - path to ckpt for evaluation用于评估的 ckpt 的路径
# IMAGES_PATH - CELEBA 数据集的路径
# MASKS_PATH - 用于测试的遮罩路径
# ANNO_PATH - 拆分文件的路径
bash scripts/run_eval_npu.sh 0 ./default_config.yaml /path/to/ckpt /path/to/img_align_celeba /path/to/testing_mask /path/to/list_eval_partitions.txt
评估日志文件储存在 ./logs/eval_log.txt
.
结果:
PSNR:
0-20%: 37.93
20-40%: 29.35
40-60%: 24.23
SSIM:
0-20%: 0.979
20-40%: 0.921
40-60%: 0.828
2022昇腾AI创新大赛昇思赛道 第二批 赛题四:利用MindSpore实现CTSDG图像修复网络
Python Shell Text
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》