|
import argparse
import os

import torch
from detectron2.solver.lr_scheduler import WarmupMultiStepLR
from monai.utils import set_determinism
from torch.optim.lr_scheduler import PolynomialLR

from dataset import get_loader
from models import choose_model
from trainer import training
-
-
def parse_args(argv=None):
    """Parse command-line options for the cardiac segmentation pipeline.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case argparse reads ``sys.argv[1:]`` (unchanged
            behavior for existing callers); passing a list makes the
            parser unit-testable.

    Returns:
        argparse.Namespace holding every pipeline option.
    """

    def str2bool(value):
        # argparse's `type=bool` is broken: bool("False") is True because
        # any non-empty string is truthy. Parse textual booleans instead.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ("yes", "true", "t", "y", "1"):
            return True
        if lowered in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"boolean value expected, got {value!r}")

    parser = argparse.ArgumentParser(description="Cardiac segmentation pipeline")

    parser.add_argument("--finetune", action="store_true", help="finetune a pretrained model, else train from scratch")
    parser.add_argument("--arch", default='vit_base', type=str, help="type of ViT")
    parser.add_argument("--a_min", default=-200.0, type=float, help="a_min in ScaleIntensityRanged")
    parser.add_argument("--a_max", default=300.0, type=float, help="a_max in ScaleIntensityRanged")
    parser.add_argument("--batch_size", default=48, type=int, help="number of batch size")
    parser.add_argument("--cpu_num", default=8, type=int, help="number of cpu")
    parser.add_argument("--cache_rate", default=0.0, type=float, help="cache rate in CacheDataset")
    parser.add_argument("--cache_num", default=96, type=int, help="number of samples cache in CacheDataset")
    parser.add_argument("--dataset", default="", type=str, help="dataset")
    parser.add_argument("--data_dir", default="", type=str, help="dataset directory")
    parser.add_argument("--demo_interval", default=20, type=int, help="plotting demo for every demo_interval step")
    parser.add_argument("--epoch_end", default=200, type=int, help="the end epoch of training")
    parser.add_argument("--epoch_start", default=0, type=int, help="the start epoch of training")
    parser.add_argument("--freeze", default="", type=str, help="freeze some weights or not")
    parser.add_argument("--gpu", default="1", type=str, help="gpu id")
    parser.add_argument("--input_size", default=224, type=int, help="image size for network input")
    parser.add_argument("--in_channels", default=1, type=int, help="number of input channels")
    # BUG FIX: was `type=bool`, which made any value (even "False") parse as True.
    parser.add_argument("--include_background", default=False, type=str2bool, help="whether to include background when calculate dice loss")
    parser.add_argument("--lr", default=5e-4, type=float, help="learning rate")
    parser.add_argument("--lr_scheduler", default="", type=str, help="learning scheduler")
    parser.add_argument("--lr_decay_epoch", default=175, type=int, help="epoch learning rate decay")
    parser.add_argument("--lossw_ce", default=1.0, type=float, help="weight for ce loss")
    parser.add_argument("--lossw_dice", default=1.0, type=float, help="weight for dice loss")
    parser.add_argument("--c_w", default=0.5, type=float, help="weight for cancer")
    parser.add_argument("--o_w", default=1.0, type=float, help="weight for organs")
    parser.add_argument('--local_rank', default=-1, type=int, help="node rank for distributed training")
    parser.add_argument("--model_name", default='', type=str, help="network used for segmentation")
    parser.add_argument("--norm", default='batch', type=str, help="normalization layer type")
    parser.add_argument("--num_pos", default=1, type=int, help="number of positive samples for RandCropByPosNegLabeld")
    parser.add_argument("--num_neg", default=3, type=int, help="number of negative samples for RandCropByPosNegLabeld")
    parser.add_argument("--num_samples", default=4, type=int, help="number of samples for RandCropByPosNegLabeld")
    parser.add_argument("--num_classes", default=15, type=int, help="number of segmentation classes, including background")
    parser.add_argument("--output_dir", default="/home/models/FLARE23/exp24_unet2d_o5c_plabel", type=str, help="directory to save the outputs")
    parser.add_argument("--plot_col", default=8, type=int, help="number of columns in demo")
    parser.add_argument("--plot_row", default=3, type=int, help="number of rows in demo")
    parser.add_argument("--plot_slices", default=4, type=int, help="number of slice interval in demo")
    parser.add_argument("--pretrained_model", default="", type=str, help="pretrained model path")
    parser.add_argument("--resume_ckpt", default='', type=str, help="resume training from pretrained checkpoint")
    parser.add_argument('--seed', default=0, type=int, help="random seed")
    parser.add_argument("--sw_batch_size", default=4, type=int, help="number of sliding window batch size")
    parser.add_argument("--val_interval", default=1, type=int, help="number of intervals to validate and save models")
    parser.add_argument("--warm_up", default=20, type=int, help="warm up epochs")
    parser.add_argument("--workers", default=8, type=int, help="number of workers")

    return parser.parse_args(argv)
-
-
def main():
    """Entry point: configure the environment, build the model, optimizer and
    LR scheduler, then launch training.

    Raises:
        ValueError: if ``--lr_scheduler`` is not one of 'warm_up' / 'poly'
            (the original code hit a NameError further down in that case).
    """
    args = parse_args()
    print(args)
    set_determinism(seed=args.seed)
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # BUG FIX: the original assigned args.gpu to CUDA_DEVICE_ORDER a second
    # time, so the requested GPU id was never applied. Device selection
    # belongs in CUDA_VISIBLE_DEVICES.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True

    # Cap the CPU thread count of the common math backends to args.cpu_num.
    cpu_num = args.cpu_num
    for env_var in (
        "OMP_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "MKL_NUM_THREADS",
        "VECLIB_MAXIMUM_THREADS",
        "NUMEXPR_NUM_THREADS",
    ):
        os.environ[env_var] = str(cpu_num)
    torch.set_num_threads(cpu_num)

    args.image_size = (args.input_size, args.input_size)
    model = choose_model(args)
    model = torch.nn.DataParallel(model).to(device)

    if args.resume_ckpt:
        checkpoint = torch.load(args.resume_ckpt, map_location=device)
        out = model.load_state_dict(checkpoint['model_state_dict'])

        if args.freeze:
            for key, value in model.named_parameters():
                # DataParallel prefixes parameter names with 'module.', so the
                # second dotted component is the submodule name checked below.
                k = key.split('.')[1]
                if args.freeze == 'cancer':
                    # Only the listed organ-branch heads stay trainable;
                    # everything else (incl. the cancer branch) is frozen.
                    if k not in ['upcat_2_2', 'upcat_1_2', 'final_conv_2']:
                        value.requires_grad = False
                elif args.freeze == 'organ':
                    # Only the listed cancer-branch heads stay trainable.
                    if k not in ['upcat_2_1', 'upcat_1_1', 'final_conv_1']:
                        value.requires_grad = False
            print(out)

    else:
        print("=> no checkpoint found at '{}'".format(args.resume_ckpt))

    params_bp = [p for p in model.parameters() if p.requires_grad]
    print("Total parameters count:", len(params_bp))

    optimizer = torch.optim.AdamW(
        params_bp,
        lr=args.lr,
        weight_decay=1e-5,
    )
    if args.lr_scheduler == 'warm_up':
        # NOTE(review): detectron2's warmup_factor is usually a small
        # multiplicative fraction (e.g. 1e-3), not the learning rate itself —
        # passing args.lr here looks suspicious; confirm against the trainer.
        lr_scheduler = WarmupMultiStepLR(
            optimizer=optimizer,
            milestones=[args.lr_decay_epoch],
            gamma=0.1,
            warmup_factor=args.lr,
            warmup_iters=args.warm_up,
            warmup_method="linear",
            last_epoch=-1,
        )
    elif args.lr_scheduler == 'poly':
        lr_scheduler = PolynomialLR(optimizer, total_iters=args.epoch_end, power=0.9)
    else:
        # BUG FIX: the original fell through with lr_scheduler undefined and
        # crashed with a NameError at the training() call below.
        raise ValueError(
            "unsupported --lr_scheduler {!r}; expected 'warm_up' or 'poly'".format(args.lr_scheduler)
        )

    train_ds, train_loader, val_loader = get_loader(args)
    print("Get dataloader!")

    training(model, train_ds, train_loader, val_loader, optimizer, lr_scheduler, device, args)


if __name__ == '__main__':
    main()
|