|
- """
- usage:
- python train.py -p ./exp/exp-1 -l 10 -s 0.01 -a 0.01 -g 200 -t 0 -i 15 -b 4 -e 120 --debug
- """
- import argparse
- import os
- import sys
- import u_d.update_u_d as update_u_d
-
- import torch
- import torch.backends.cudnn as cudnn
- import random
- import numpy as np
- sys.path.append('./')
-
- os.environ["CUDA_VISIBLE_DEVICES"] = "1" #选择gpu0,当不写这个语句时,默认选择为gpu0,若想选择gpu1,则把0改为1,若此时仍然默认选择0,则把该语句放到import os下一行,往前面放
-
-
- #命令行参数定义
- # 步骤:(1)导入模块 import argparse
- # (2)实例化对象 parser = argparse.ArgumentParser()
- # (3)定义参数 涉及全称、简称、默认值default、备选choices=[]、变量类型要求type、执行动作action、(其中store_true表示命令行遇到改参数不可赋值,默认为true)变量解释说明help
- # (4)参数解析 args = parser.parse_args()
- #参数命名 -ts 更新策略 -p 保存输出数据文件夹 -u 是否预训练unet网络 --pretrain... 预训练unet路径 -k 权重的权重 --data 数据集路径
- # -b 每次训练的图像数量 --gan_type gan网络类型 --u_depth UNET网络深度 --d_depth 判别器深度 --downsampling 判别器下采样倍数
- # -lr 学习率 --beta1 Adam优化方法 -i 日志更新时间间隔 -e 训练的迭代次数 -l u和c中u的权重 -s 判别器d的loss权重
- # -g unet1和D中unet1的权重 -a unet1和D中D的权重 -t 所有变量的loss权重 --eta WGAN-GP中的梯度惩罚参数 --epsi 学习率指数衰减步长 --pretrained steps 预训练步长 --debug debug模式还是train模式 --gpu_counts GPU号 --local 数据位置
- def parse_args():
- parser = argparse.ArgumentParser(description='Training Custom Defined Model')
- parser.add_argument('--training_strategies', '-ts', default='update_u_d',
- choices=['update_c_d_u', 'add_normal_constraint', 'remove_lesion_constraint', 'update_u_d'],
- help='training strategies')
- parser.add_argument('--prefix', '-p', default='./output/COVID-CELL/20230925_1', type=str, help='parent folder to save output data')
- parser.add_argument('--power', '-k', type=int, default=2, help='power of weight')
- parser.add_argument('--dataset', type=str,
- default='Synapse', help='experiment_name')
- parser.add_argument('--data', type=str,
- default='./data_1/gan', choices=['./data_1/gan','./FLARE23dataset/train'],help='dataset dir')
- parser.add_argument('--num_classes', type=int,
- default=1, help='output channel of a single encoder-decoder')
- parser.add_argument('--batch_size', '-b', default=8, type=int, help='batch size per gpu')
- parser.add_argument('--gan_type', type=str, default='multi_scale',
- choices=['conv_bn_leaky_relu', 'multi_scale', 'top_multi_scale', 'middle_multi_scale','resnet18'],
- help='discriminator type')
- parser.add_argument('--d_depth', type=int, default=7, help='discriminator depth')
- parser.add_argument('--dowmsampling', type=int, default=4, help='dowmsampling times in discriminator')
- parser.add_argument('--d_lr', default=5e-4, type=float, help='learning rate of discriminator')
- parser.add_argument('--u_lr', default=2e-4, type=float, help='learning rate of generator')
- parser.add_argument('--beta1', type=float, default=0.5, help='beta1 in Adam')
- parser.add_argument('--interval', '-i', default=10, type=int, help='log print interval')
- parser.add_argument('--epochs', '-e', default=150, type=int, help='all of training epochs')
- parser.add_argument('--epochs_transition', '-et', default=65, type=int, help='epoch setting of localization to segmentation')
- parser.add_argument('--lmbda', '-l', default=10, type=float, help='weight of u between u and c')
- parser.add_argument('--sigma', '-s', default=0.01, type=float, help='weight of d loss')
- parser.add_argument('--gamma', '-g', default=200, type=float, help='weight of u in u & d')
- parser.add_argument('--alpha', '-a', default=0.01, type=float, help='weight of d in u & d')
- parser.add_argument('--theta', '-t', default=0, type=float, help='weight of total variation loss')
- parser.add_argument('--eta', type=float, default=10.0, help='gradient penalty')
- parser.add_argument('--epsi', type=float, default=1.0, help='learning rate exponential decay step')
- parser.add_argument('--pretrained_steps', type=int, default=0, help='pretrained steps')
- parser.add_argument('--debug', action='store_true', default=False, help='mode:training or debug')
- # parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
- parser.add_argument('--deterministic', type=int, default=1,
- help='whether use deterministic training')
- parser.add_argument('--img_size', type=int,
- default=224, help='input patch size of network input')
- parser.add_argument('--seed', type=int,
- default=1234, help='random seed')
- parser.add_argument('--cfg', type=str, default='./configs/swin_tiny_patch4_window7_224_lite.yaml', metavar="FILE",
- help='path to config file', )
- parser.add_argument("--opts", help="Modify config options by adding 'KEY VALUE' pairs. ", default=None, nargs='+', )
- parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
- parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
- help='no: no cache, '
- 'full: cache all data, '
- 'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
- parser.add_argument('--resume', help='resume from checkpoint')
- parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
- parser.add_argument('--use-checkpoint', action='store_true',
- help="whether to use gradient checkpointing to save memory")
- parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
- help='mixed precision opt level, if O0, no amp is used')
- parser.add_argument('--tag', help='tag of experiment')
- parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
- parser.add_argument('--throughput', action='store_true', help='Test throughput only')
- parser.add_argument('--local', action='store_true', default=False, help='data location')
- return parser.parse_args()
-
- args = parse_args()
- print(type(args))
-
- def main():
- if not args.deterministic: #使用不确定的训练方式,不保证可重复性,使用cudnn进加速卷积运算时保证精度但大幅降低计算效率
- cudnn.benchmark = True
- cudnn.deterministic = False
- else: #保证可重复性,使用cudnn进行卷积运算时稍微牺牲精度提高计算效率(默认且首选)
- cudnn.benchmark = False
- cudnn.deterministic = True
-
- random.seed(args.seed) #python形式使用random函数时设置随机种子
- np.random.seed(args.seed) #numpy形式使用random函数时设置确定的随机种子
- torch.manual_seed(args.seed) #为CPU设置随机种子
- torch.cuda.manual_seed_all(args.seed) #为当前GPU设置随机种子
-
- if args.training_strategies == 'update_u_d':
- trainer = update_u_d.update_u_d(args) #实例化update_u_d的对象为trainer,输入参数为args对象
- script_path = './u_d/update_u_d.py' #存储update_u_d的py脚本路径
- print('training step:')
- print('(1)fix D, update G')
- print('(2)fix G, update D')
- else:
- raise ValueError('{} has not been implemented'.format(args.training_strategies))
- trainer.save_running_script(script_path) #保存更改前后的运行脚本,将更改前的路径为~/u_d/update_u_d.py文件代码保存到~/exp/exp-1/update_u_d.py文件中去
- trainer.main()
- trainer.save_log()
-
-
- if __name__ == '__main__':
- main()
|