import glob
import os

from utils.CD_dataset import CDDataset

try:
    """This block fixes container image/mirror errors; just run it as-is and avoid changing it."""
    os.system("apt-get update -y")
    os.system("apt install libgl1-mesa-glx -y")
    os.system("apt-get install libglib2.0-dev -y")

    print("Environment fix succeeded")
except Exception:
    print("Environment fix failed")
try:
    """This block configures the environment automatically; requirements.txt lists the required packages."""
    os.system("pip install -r requirements.txt -i https://pypi.douban.com/simple/")
    print("Dependency install succeeded")
except Exception:
    print("Dependency install failed")
import time
import datetime
import math
import torch
from utils.c_train_and_eval import train_one_epoch, evaluate, create_lr_scheduler
from utils.change_data import MyDataset
from models.models.snunet import SNUNet
from utils.c_distributed_utils import ConfusionMatrix
from utils.distributed_utils import set_seed
from models.FC.FC_EF import FresUNet
from models.models.siamunet_conc import SiamUNet_conc
from models.models.siamunet_diff import SiamUNet_diff
from models.models.stanet import STANet
from models.models.ifn import DSIFN
from models.models.cdnet import CDNet
from models.models import unet
# from models.models.myne_corsst import MyNet
# from models.models.mynet_corsst3 import MyNet
# from models.models.mynet import MyNet
from models.mynet.mynet3 import MyNet
import warnings
from models.models.DMINet import DMINet
# from models.mynet.CDNet import MyNet
# from models.MGL.mynet import MyNet
# `from utils import EMA` was shadowed by the ema_pytorch import, so only the
# ema_pytorch implementation is kept.
from ema_pytorch import EMA
warnings.filterwarnings("ignore")
- """
- 读取数据集:RGB三通道,0-255范围内
- 变化检测
- """
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
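# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous so errors are
# reported at the offending call; it helps debugging but slows training down.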


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    batch_size = args.batch_size
    num_classes = args.num_classes + 1

    # File used to record training and validation information
    results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    train_dataset = MyDataset(args.data_path)
    val_dataset = MyDataset(args.val_path)
    # train_dataset = CDDataset(root_dir=args.data_path, split='train',
    #                           img_size=256, is_train=True,
    #                           label_transform="norm")
    # val_dataset = CDDataset(root_dir=args.val_path, split="val",
    #                         img_size=256, is_train=False,
    #                         label_transform="norm")
    num_workers = 4
    print("num_workers:", num_workers)
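    # num_workers is fixed at 4; on machines with fewer CPU cores it may be
    # worth lowering this to avoid dataloader worker contention.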
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             pin_memory=True)
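    # The validation loader keeps the default shuffle=False so evaluation is
    # deterministic across epochs; pin_memory speeds up host-to-GPU copies.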
    # model = SNUNet(3, 2)

    model = MyNet(3, 2)
    # model = DMINet(num_classes=2)
    # model = FresUNet(6, 2)
    # model = SiamUNet_conc(3, 2)
    # model = SiamUNet_diff(3, 2)
    # model = STANet(in_ch=3)
    model.to(device)
    model_ema = None
    if args.model_ema:
        model_ema = EMA(
            model,
            beta=0.9999,           # exponential moving average factor
            update_after_step=32,  # start updating only after this many .update() calls
            update_every=1,        # update on every .update() call
        )
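    # ema_pytorch's EMA keeps a shadow copy of the weights, which
    # train_one_epoch is expected to update via its `ema` argument. Note that
    # --model_ema_decay and --model_ema_steps are parsed below but not wired
    # into this constructor, which uses hardcoded values.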

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs,
                                       warmup=True, warmup_epochs=1)
    # Alternative optimizer/scheduler choices kept for reference:
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.2, last_epoch=-1)
    # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=25, last_epoch=-1)
    # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None
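    # create_lr_scheduler is a project utility stepped per iteration (hence
    # len(train_loader)) with a one-epoch warmup; GradScaler is created only
    # when --amp is set, so full-precision runs are unaffected.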

    # if args.ckpt_url:
    #     print("Loading pretrained weights:", args.ckpt_url)
    #     checkpoint = torch.load(args.ckpt_url, map_location='cpu')
    #     model.load_state_dict(checkpoint['model'])
    # Optionally resume training from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        # Mixed-precision training: restore the scaler state too
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])
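    # Resuming restores model/optimizer/scheduler (and, under AMP, scaler)
    # state and continues from the epoch after the one saved in the checkpoint.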

    # Start timing
    start_time = time.time()
    best_F1 = 0.
    Last_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        # print_model_info(model)
        mean_loss, lr = train_one_epoch(
            model, optimizer, train_loader, device, epoch,
            lr_scheduler=lr_scheduler,
            print_freq=args.print_freq,
            num_classes=2,
            scaler=scaler, ema=model_ema)
        confmat = evaluate(model, val_loader,
                           device=device,
                           num_classes=num_classes, print_freq=args.print_freq)
        val_info = ConfusionMatrix.todict(confmat)
        val_info_print = str(confmat)
        F1 = float(val_info['F1_Score'][1])
        print(val_info_print)
        if model_ema:
            confmat = evaluate(model_ema, val_loader,
                               device=device,
                               num_classes=num_classes, print_freq=args.print_freq)
            val_info = ConfusionMatrix.todict(confmat)
            val_info_print = str(confmat)
            F1 = float(val_info['F1_Score'][1])
            print(val_info_print)

        # Guard against NaN F1 (e.g. when no positive pixels are predicted);
        # F1 is already a float here, so use math.isnan rather than a string test.
        if math.isnan(F1):
            F1 = 0.
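        # When EMA is enabled, the metrics above come from the EMA model, so
        # the best-checkpoint logic that follows tracks the smoothed weights.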
        save_txt = os.path.join(args.out_path, results_file)
        print(save_txt)
        with open(save_txt, "a") as f:
            # Record each epoch's train_loss, lr, and validation metrics
            train_info = f"[epoch: {epoch}]\n" \
                         f"train_loss: {mean_loss:.4f}\n" \
                         f"lr: {lr:.6f}\n"
            f.write(train_info + val_info_print + "\n\n")

        save_file = {"model": model.state_dict(),
                     "optimizer": optimizer.state_dict(),
                     "lr_scheduler": lr_scheduler.state_dict(),
                     "epoch": epoch,
                     "args": args}
        if args.amp:
            save_file["scaler"] = scaler.state_dict()
        if F1 > best_F1:
            best_F1 = F1
            Last_epoch = epoch
            # if epoch > 90:
            # Checkpoint save location on the server
            # model_ema_name = str(epoch) + "model_ema_best.pth"
            # model_name = str(epoch) + "model_best.pth"
            # model_ema_name = "model_ema_best.pth"
            model_name = "model_best.pth"
            # save_url_ema = os.path.join(args.out_path, model_ema_name)
            save_url = os.path.join(args.out_path, model_name)
            print(save_url)
            # Save the full checkpoint dict so that --resume can restore it
            torch.save(save_file, save_url)
            # torch.save(model_ema, save_url_ema)

    print("Best F1:", best_F1)
    print("Best model found at epoch {}".format(Last_epoch))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))


def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="pytorch change-detection training")
    parser.add_argument("--ckpt_url", default="", help="path to a pretrained checkpoint")
    parser.add_argument("--data_path", default="/tmp/dataset/train", help="training data root")
    parser.add_argument("--val_path", default="/tmp/dataset/val", help="validation data root")
    parser.add_argument("--num-classes", default=1, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("--out_path", default="/tmp/output", help="output directory")
    parser.add_argument("-b", "--batch-size", default=32, type=int)
    parser.add_argument("--epochs", default=100, type=int, metavar="N",
                        help="number of total epochs to train")
    parser.add_argument('--lr', default=0.0002, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--print-freq', default=100, type=int, help='print frequency')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    # store_true flags: with type=bool any non-empty string (even "False") is truthy
    parser.add_argument("--amp", action="store_true",
                        help="use torch.cuda.amp for mixed precision training")
    parser.add_argument("--model_ema", action="store_true",
                        help="enable exponential moving average of model weights")
    parser.add_argument("--model_ema_decay", default=0.99998, type=float,
                        help="decay factor for the model EMA")
    parser.add_argument("--model_ema_steps", default=32, type=int,
                        help="number of steps between model EMA updates")
    parser.add_argument("--seed", default=10, type=int,
                        help="random seed")
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    # Automated configuration for training on the intelligent-computing cluster
    args = parse_args()
    # set_seed(args)

    main(args)
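
# Example launch (script name and paths are illustrative; defaults match the
# cluster layout used above):
#   python train.py --data_path /tmp/dataset/train --val_path /tmp/dataset/val \
#       --out_path /tmp/output -b 32 --epochs 100 --lr 0.0002 --amp --model_ema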