|
- import sys
- sys.path.insert(0, '.')
- import os
- import os.path as osp
- import random
- import logging
- import time
- import json
- import argparse
- import numpy as np
- from tabulate import tabulate
-
- import math
-
- import datetime
-
- import torch
- import torch.nn as nn
- import torch.distributed as dist
- from torch.utils.data import DataLoader
- import torch.cuda.amp as amp
- import torch.optim as optim
- import torch.optim.lr_scheduler as lr_scheduler
-
-
- from _250_bisenetv2_backboneNotLoad import BiSeNetV2
-
- from _430_customZCDataSet_230503 import CustomZCDataSet
-
- import torch.nn.functional as F
- from tqdm import tqdm
-
-
- import torch.nn.functional as F
- from tqdm import tqdm
-
-
- from torch.utils.tensorboard import SummaryWriter
-
class MscEvalV0(object):
    """Single-scale mean-IoU (mIoU) evaluator.

    Downscales each input image by ``scale`` before the forward pass,
    upsamples the resulting logits back to label resolution, and
    accumulates a per-class confusion matrix from which mIoU is derived.
    """

    def __init__(self, scale=0.5, ignore_label=255):
        # Label value marking pixels that are excluded from evaluation.
        self.ignore_label = ignore_label
        # Inputs are resized to (H * scale, W * scale) before inference.
        self.scale = scale

    def __call__(self, net, dl, n_classes):
        """Evaluate ``net`` over dataloader ``dl`` and return mIoU as a float.

        net       -- segmentation model; ``net(imgs)[0]`` is taken as the main
                     logits (aux heads follow) -- matches BiSeNetV2's forward
        dl        -- iterable yielding (imgs, label) batches; label is
                     (N, 1, H, W) -- presumably integer class ids, TODO confirm
        n_classes -- number of classes (confusion matrix is n_classes^2)
        """
        ## evaluate
        # Confusion-matrix accumulator: hist[gt, pred] = pixel count.
        hist = torch.zeros(n_classes, n_classes).cuda().detach()
        # Only rank 0 shows a progress bar when running distributed.
        if dist.is_initialized() and dist.get_rank() != 0:
            diter = enumerate(dl)
        else:
            diter = enumerate(tqdm(dl))
        for i, (imgs, label) in diter:
            N, _, H, W = label.shape

            label = label.squeeze(1).cuda()
            # Logits are upsampled back to the full label resolution.
            size = label.size()[-2:]

            imgs = imgs.cuda()

            N, C, H, W = imgs.size()
            new_hw = [int(H * self.scale), int(W * self.scale)]

            imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)

            logits = net(imgs)[0]

            logits = F.interpolate(logits, size=size,
                                   mode='bilinear', align_corners=True)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)
            keep = label != self.ignore_label
            # Flattened (gt * n_classes + pred) bincount builds the whole
            # confusion matrix in one vectorized pass.
            hist += torch.bincount(
                label[keep] * n_classes + preds[keep],
                minlength=n_classes ** 2
            ).view(n_classes, n_classes).float()
        # Sum confusion matrices across all distributed workers.
        if dist.is_initialized():
            dist.all_reduce(hist, dist.ReduceOp.SUM)
        # Per-class IoU: diag / (col sum + row sum - diag); mIoU is the mean.
        # NOTE(review): a class absent from both gt and pred yields 0/0 = NaN,
        # which propagates into the mean -- confirm this cannot occur here.
        ious = hist.diag() / (hist.sum(dim=0) + hist.sum(dim=1) - hist.diag())
        miou = ious.mean()
        return miou.item()
-
-
-
class OhemCELoss(nn.Module):
    """Online Hard Example Mining (OHEM) cross-entropy loss.

    Core idea: only pixels whose per-pixel CE loss exceeds a threshold
    participate in the loss, but at least ``n_min`` pixels (one sixteenth
    of the valid ones) are always kept.
    """

    def __init__(self, thresh, lb_ignore=255):
        super(OhemCELoss, self).__init__()
        # Convert the probability threshold into loss space: a pixel whose
        # softmax probability is below ``thresh`` has CE loss > -log(thresh).
        # e.g. thresh=0.7 -> self.thresh ~= 0.3567
        prob = torch.tensor(thresh, requires_grad=False, dtype=torch.float)
        self.thresh = -torch.log(prob).cuda()
        # Label value excluded from the loss (default 255).
        self.lb_ignore = lb_ignore
        # reduction='none' keeps the per-pixel losses so we can select
        # the hard examples ourselves.
        self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='none')

    def forward(self, logits, labels):
        """logits: (N, C, H, W); labels: (N, H, W) -> scalar mean hard loss."""
        # Lower bound on participating pixels: 1/16 of the non-ignored ones.
        valid = labels != self.lb_ignore
        n_min = valid.sum().item() // 16
        # Per-pixel CE loss, flattened to one dimension.
        pixel_loss = self.criteria(logits, labels).flatten()
        # Hard examples are those above the threshold (this avoids a full
        # sort, which would be the slow part of a naive OHEM).
        hard = pixel_loss[pixel_loss > self.thresh]
        # If too few pixels qualify, fall back to the top-n_min losses.
        if hard.numel() < n_min:
            hard = pixel_loss.topk(n_min)[0]
        # Mean over the hard examples only. Note: this under-reports the
        # true dataset loss, since easy (below-threshold) pixels are
        # deliberately excluded from the average.
        return torch.mean(hard)
-
class ConfigClass():
    """Central place for the (hard-coded) training configuration."""

    @staticmethod
    def getConfig():
        """Return the config dict and echo it to stdout."""
        ## bisenetv2
        # BiSeNetV2 backbone, 2 output classes, 4 auxiliary heads.
        cfg = dict(
            model_type='bisenetv2',
            n_classes=2,
            num_aux_heads=4,
        )
        print(cfg)
        return cfg
-
-
class TrainClass():
    """Top-level training driver for BiSeNetV2 on the custom ZC dataset."""

    @staticmethod
    def trainMain():
        """Build data loaders, model, optimizer and losses, then run the
        training loop, saving checkpoints every 2 epochs and whenever the
        mean epoch loss reaches a new minimum."""


        # ==========================================================
        # 10. Initialize parameters
        # ==========================================================
        print("====================== 10、初始化参数区 ======================")

        cfg = ConfigClass.getConfig()


        # ==========================================================
        # 20. Build datasets
        # ==========================================================
        print("====================== 20、构建数据集 ======================")

        bmpImgRoot= r'/dataset/data/unused'
        labelImgRoot= r'/dataset/data/unused'
        trainTxtPath= r'/code/data/train.txt'
        testTxtPath= r"/code/data/test.txt"

        # cropSize = [928, 1120]
        cropSize = [896, 1288]
        # cropsize = [1536, 1024]
        # cropsize = [768, 512]
        # 230501: scales may only shrink, never enlarge much -- otherwise a
        # crop might no longer contain a complete product.
        randomScale = (0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0,1.025,1.05)

        rotateAngleRange = (-3, 3)

        # dl = get_data_loader(cfg, mode='train')
        trainDataSet = CustomZCDataSet(bmpImgRoot, labelImgRoot, trainTxtPath, cropsize=cropSize, mode="train",
                                       randomscale=randomScale,rotateAngleRange=rotateAngleRange,saveImg=False)
        # sampler = torch.utils.data.distributed.DistributedSampler(ds)
        trainDataLoader = DataLoader(trainDataSet,
                                     batch_size=2,
                                     shuffle=True,
                                     # num_workers = 4,
                                     pin_memory=True,
                                     drop_last = True
                                     )
        # exit(0)
        testDataSet = CustomZCDataSet(bmpImgRoot,labelImgRoot, testTxtPath, mode='val',randomscale=randomScale)
        # sampler_val = torch.utils.data.distributed.DistributedSampler(dsval)
        testDataLoader = DataLoader(testDataSet,
                                    batch_size=2,
                                    shuffle=False,
                                    # num_workers = 4,
                                    pin_memory=True,
                                    # drop_last = False
                                    )


        # ==========================================================
        # 30. Initialize the model
        # ==========================================================
        print("====================== 30、初始化模型 ======================")
        model = BiSeNetV2(cfg["n_classes"])
        model.cuda()
        # TODO:230420: pretrained weights still need to be loaded here
        # loadModelPath = "pth/ok_230420/lossMin_epoch_311_iter_8086_loss_Min_3.13.pth"
        # loadModelPath = r"pth_save/Group3/lossMin_epoch_176_iter_19888_loss_min_3.0558.pth"
        loadModelPath = r""
        isOtherModel = False
        # Resume training: load existing weights when a path is configured.
        if len(loadModelPath):
            print("加载预训练参数")

            # Loading a third-party pretrained model: the class count
            # differs, so the classifier output layers must be dropped.
            if isOtherModel:
                print("1、加载【别人的】预训练参数")
                weights_dict = torch.load(loadModelPath)
                newWeightsDict = {}
                for k, v in weights_dict.items():
                    # 230504: the reported mismatching keys look like
                    # ".conv_out.1.", so those are filtered out.
                    if ".conv_out.1." not in k:
                        newWeightsDict[k] = v
                # Non-strict load: missing/extra keys are tolerated.
                model.load_state_dict(newWeightsDict,strict=False)
            # Loading our own checkpoint: class count matches, strict load.
            else:
                print("2、加载【自己的】预训练参数")
                model.load_state_dict(torch.load(loadModelPath))

        # print(model)


        isPrintModel = False
        if isPrintModel:
            # Dump the network graph to TensorBoard for inspection.
            input = torch.ones((2, 3, 320, 352))
            input = input.cuda()
            writer = SummaryWriter("logs/model")
            # Write the model graph.
            writer.add_graph(model, input)
            writer.close()


        # ==========================================================
        # 40. Define the optimizer
        # ==========================================================
        print("====================== 40、定义优化器 ======================")
        learning_rate = 1e-4
        # optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.005)
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.005)
        # Scheduler https://arxiv.org/pdf/1812.01187.pdf
        # Cosine decay with period 200; by step 200 the multiplier has fallen
        # to roughly learning_rate (i.e. ~1/10000 of the start here).
        # NOTE(review): scheduler.step() is called per *iteration* below, so
        # the period of 200 is in iterations, not epochs -- confirm intent.
        learnSchedulerFunc = lambda x: ((1 + math.cos(x * math.pi / 200)) / 2) * (1 - learning_rate) + learning_rate  # cosine
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=learnSchedulerFunc)

        # ==========================================================
        # 50. Define the loss functions
        # ==========================================================
        print("====================== 50、定义损失函数 ======================")
        ignoreLabelVal = 255
        # One OHEM loss for the main head and one per auxiliary head.
        criteria_pre = OhemCELoss(0.7, ignoreLabelVal)
        criteria_aux = [OhemCELoss(0.7, ignoreLabelVal) for _ in range(cfg["num_aux_heads"])]



        # ==========================================================
        # 70. Training loop
        # ==========================================================
        print("====================== 70、开始训练 ======================")

        trainTotalIter = 0
        testTotalIter = 0

        # Running bests used for "save the best so far" checkpoints.
        mIOU50_max = -999999999
        mIOU75_max = -999999999
        trainEpochMeanLoss_min = 999999999

        todayStr = datetime.datetime.now().strftime('%Y%m%d')
        savePthPath_base = r"/model/pth"
        os.makedirs(savePthPath_base, exist_ok=True)
        # 230420: include cropSize in the checkpoint directory name
        savePthPath = os.path.join(savePthPath_base, todayStr+"_cropSize_"+str(cropSize[0]) +"x"+str(cropSize[1]))
        os.makedirs(savePthPath, exist_ok=True)


        # TensorBoard writer for the loss/accuracy curves.
        writer = SummaryWriter("logs/loss_and_accuracy")

        len_dataLoader = len(trainDataLoader)

        for epoch in range(9000):
            print("-------第 {} 轮训练开始-------".format(epoch + 1))
            ##########################################
            # 10. Train one epoch
            ##########################################
            timeTrainModelBegin = time.time()

            # Reset the per-epoch accumulators.
            trainTotalLoss = 0.0
            iterNumPerEpoch = 0

            # Switch to training mode.
            model.train()
            optimizer.zero_grad()

            # Iterate over the training batches.
            for dl_step, data in enumerate(trainDataLoader):
                timeBatchBegin = time.time()

                trainTotalIter = trainTotalIter + 1
                iterNumPerEpoch = iterNumPerEpoch + 1

                inputs, labels = data[0], data[1]

                im = inputs.cuda()
                lb = labels.cuda()

                # Drop the channel dim: (N, 1, H, W) -> (N, H, W).
                lb = torch.squeeze(lb, 1)


                # Main-head logits plus auxiliary-head logits.
                logits, *logits_aux = model(im)
                loss_pre = criteria_pre(logits, lb)
                loss_aux = [crit(lgt, lb) for crit, lgt in zip(criteria_aux, logits_aux)]
                loss = loss_pre + sum(loss_aux)


                loss.backward()

                optimizer.step()
                optimizer.zero_grad()

                # Advance the learning-rate schedule (per iteration).
                scheduler.step()

                # Accumulate the epoch's total loss.
                trainTotalLoss+=loss.item()

                # Log every 10 steps and at the last step of the epoch.
                if 0==(dl_step+1)%10 or (dl_step+1) == len_dataLoader:
                    lr = scheduler.get_last_lr()
                    lr = sum(lr) / len(lr)
                    print("it: {}, lr: {}, loss: {},loss_pre: {},".format(dl_step + 1, lr,loss.item(), loss_pre.item()))

                # if 0==trainTotalIter%1000:
                #     print("隔指定迭代dataloader保存模型: 当前iter为{},保存模型!".format(trainTotalIter))
                #     # 保存模型
                #     save_pth = osp.join(savePthPath, 'iterSave_epoch_{}_iter_{}_loss_{}_loss_pre_{}.pth'
                #                         .format(epoch + 1,trainTotalIter, str(round(loss.item(), 4)), str(round(loss_pre.item(), 4))))
                #     modelState = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
                #     torch.save(modelState, save_pth)



            # Per-epoch checkpointing and loss reporting.

            # Report the mean loss over this epoch.
            trainTotalLoss_mean = trainTotalLoss/len_dataLoader
            print("epoch: {}, trainTotalLoss_mean: {} ".format(epoch + 1,trainTotalLoss_mean))

            writer.add_scalar("train_loss", trainTotalLoss_mean, epoch + 1)

            # Save a checkpoint every 2 epochs.
            if 0==epoch%2:
                print("隔指定回合保存模型: 当前epoch为{},保存模型!".format(epoch))
                save_pth = osp.join(savePthPath, 'epoch_{}_iter_{}_loss_mean_{}.pth'
                                    .format(epoch + 1, trainTotalIter, str(round(trainTotalLoss_mean, 4))))
                modelState = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
                torch.save(modelState, save_pth)

            # Also save whenever the mean epoch loss hits a new minimum.
            if trainTotalLoss_mean<=trainEpochMeanLoss_min:
                print("保存损失最小模型: 当前损失为{},保存模型!".format(trainTotalLoss_mean))
                trainEpochMeanLoss_min = trainTotalLoss_mean
                save_pth = osp.join(savePthPath, 'lossMin_epoch_{}_iter_{}_loss_min_{}.pth'
                                    .format(epoch + 1, trainTotalIter, str(round(trainEpochMeanLoss_min, 4))))
                modelState = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
                torch.save(modelState, save_pth)


            # Disabled: periodic mIoU evaluation + best-mIoU checkpointing.
            # if epoch>200 and 0==(epoch+1) % 20:
            #
            #     # ## evaluator
            #     print('计算 mIOU')
            #     with torch.no_grad():
            #         print("计算 mIOU50")
            #         single_scale1 = MscEvalV0()
            #         mIOU50 = single_scale1(model, trainDataLoader, cfg["n_classes"])
            #
            #         print("计算 mIOU70")
            #         single_scale2 = MscEvalV0(scale=0.75)
            #         mIOU75 = single_scale2(model, trainDataLoader, cfg["n_classes"])
            #
            #     writer.add_scalar("train_mIOU50", mIOU50, epoch + 1)
            #     writer.add_scalar("train_mIOU75", mIOU75, epoch + 1)
            #
            #
            #     # Compute mIoU on the validation set.
            #     print('计算 验证集 mIOU')
            #     with torch.no_grad():
            #         print("计算 验证集 mIOU50")
            #         single_scale1 = MscEvalV0()
            #         mIOU50_test = single_scale1(model, testDataLoader, cfg["n_classes"])
            #
            #         print("计算 验证集 mIOU70")
            #         single_scale2 = MscEvalV0(scale=0.75)
            #         mIOU75_test = single_scale2(model, testDataLoader, cfg["n_classes"])
            #
            #     writer.add_scalar("test_mIOU50", mIOU50_test, epoch + 1)
            #     writer.add_scalar("test_mIOU75", mIOU75_test, epoch + 1)

            # Probably not worth saving these checkpoints.
            # if False:
            #     save_pth = osp.join(savePthPath, 'mIOU_epoch_{}_iter_{}_loss_min_{}_mIOU50_{}_mIOU75_{}.pth'
            #                         .format(epoch + 1, trainTotalIter, str(round(trainEpochMeanLoss_min, 4)),str(round(mIOU50, 4)), str(round(mIOU75, 4))))
            #     modelState = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
            #     torch.save(modelState, save_pth)


            # if mIOU50 > mIOU50_max:
            #     mIOU50_max = mIOU50
            #     save_pth = osp.join(savePthPath,
            #                         'mIOU50_max_epoch_{}_iter_{}_loss_min_{}_mIOU50_{}_mIOU75_{}.pth'
            #                         .format(epoch + 1, trainTotalIter, str(round(trainEpochMeanLoss_min, 4)),
            #                                 str(round(mIOU50, 4)), str(round(mIOU75, 4))))
            #     modelState = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
            #     torch.save(modelState, save_pth)


            # if mIOU75 > mIOU75_max:
            #     mIOU75_max = mIOU75
            #     save_pth = osp.join(savePthPath,
            #                         'mIOU75_max_epoch_{}_iter_{}_loss_min_{}_mIOU50_{}_mIOU75_{}.pth'
            #                         .format(epoch + 1, trainTotalIter, str(round(trainEpochMeanLoss_min, 4)),
            #                                 str(round(mIOU50, 4)), str(round(mIOU75, 4))))
            #     modelState = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
            #     torch.save(modelState, save_pth)


            pass
-
-
-
# Script entry point: start training when executed directly.
if __name__ == '__main__':
    TrainClass.trainMain()
-
- """
-
- """
-
-
-
|