|
- import os
- import time
- import numpy as np
- import matplotlib.pyplot as plt
- import torch
- from torch.utils.tensorboard import SummaryWriter
- from monai.data import decollate_batch
- from monai.metrics import DiceMetric
- from monai.inferers import sliding_window_inference
- from monai.transforms import (
- AsDiscrete,
- Compose,
- EnsureType,
- )
-
- from loss import LossFunction
-
-
# Shared AMP gradient scaler used by train_step for mixed-precision updates.
scaler = torch.cuda.amp.GradScaler()

# Post-processing chains for validation:
# organ head — argmax then one-hot over 14 classes (13 organs + background).
post_pred_o = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=14)])
post_label_o = Compose([EnsureType(), AsDiscrete(to_onehot=14)])
# cancer head — binary (background vs. cancer).
post_pred_c = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label_c = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

# Dice accumulators, filled batch-by-batch in valid_step and reset after aggregation.
metric_o = DiceMetric(include_background=False, reduction="mean")
metric_c = DiceMetric(include_background=False, reduction="mean")
-
-
def training(model, train_ds, train_loader, val_loader, optimizer, lr_scheduler, device, args):
    """Main training loop: optimize, log to TensorBoard, validate, checkpoint.

    Args:
        model: network returning a (cancer_output, organ_output) pair.
        train_ds: training dataset (not used directly; kept for interface
            compatibility with callers).
        train_loader: DataLoader yielding dicts with "image" and "label".
        val_loader: validation DataLoader, consumed by valid_step.
        optimizer: optimizer stepped (via the AMP scaler) in train_step.
        lr_scheduler: stepped once per epoch.
        device: torch device for all tensors.
        args: config namespace — epoch_start/epoch_end, val_interval,
            demo_interval, output_dir, lossw_dice/lossw_ce, etc.
    """
    loss_function = LossFunction(args)
    loss_weights = [args.lossw_dice, args.lossw_ce]

    best_c_metric = -1
    best_o_metric = -1
    best_c_epoch = -1
    best_o_epoch = -1
    epoch_loss_values = []
    writer = SummaryWriter(args.output_dir)

    demo_dir = os.path.join(args.output_dir, 'demo')
    # exist_ok=True avoids the check-then-create race of exists()+makedirs().
    os.makedirs(demo_dir, exist_ok=True)

    for epoch in range(args.epoch_start, args.epoch_end):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{args.epoch_end}")

        model.train()
        epoch_loss, epoch_dc_loss, epoch_ce_loss = 0, 0, 0

        step = 0
        for batch_data in train_loader:
            step += 1
            loss, loss_list = train_step(model, optimizer, batch_data, epoch, step, loss_function, loss_weights, demo_dir, device, args)
            epoch_loss += loss.item()
            epoch_dc_loss += loss_list[0].item()
            epoch_ce_loss += loss_list[1].item()
            print(
                f"{step}/{len(train_loader)}"
                f", epoch_loss: {loss.item():.4f}"
                f", dice_loss: {loss_list[0].item():.4f}"
                f", ce_loss: {loss_list[1].item():.4f}"
            )

        lr_scheduler.step()
        # Convert per-step sums into per-epoch means.
        epoch_loss /= step
        epoch_dc_loss /= step
        epoch_ce_loss /= step
        epoch_loss_values.append([epoch_loss, epoch_dc_loss, epoch_ce_loss])

        writer.add_scalar("train_loss", epoch_loss, epoch)
        writer.add_scalar("train_dc_loss", epoch_dc_loss, epoch)
        writer.add_scalar("train_ce_loss", epoch_ce_loss, epoch)
        writer.add_scalar("learning_rate", optimizer.param_groups[0]['lr'], epoch)

        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        print(
            f"current epoch: {epoch + 1}, current train dice: {epoch_dc_loss:.4f}, current train ce: {epoch_ce_loss:.4f}"
        )

        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': lr_scheduler.state_dict()
        }

        if (epoch + 1) % args.val_interval == 0:
            dice_cancer, dice_organ = valid_step(model, val_loader, device, args)
            writer.add_scalar("valid_cancer_dice", dice_cancer, epoch)
            writer.add_scalar("valid_organ_dice", dice_organ, epoch)

            if dice_cancer > best_c_metric:
                best_c_metric = dice_cancer
                best_c_epoch = epoch + 1
            if dice_organ > best_o_metric:
                best_o_metric = dice_organ
                best_o_epoch = epoch + 1

            # Validation epochs get an extra checkpoint named after the dice scores.
            torch.save(
                checkpoint,
                os.path.join(
                    args.output_dir,
                    f"epoch{epoch + 1}-cancer_dc{dice_cancer:.4f}-organ_dc{dice_organ:.4f}.pth"
                )
            )

            print(
                f"current cancer dice {dice_cancer:.4f}, organ dice {dice_organ:.4f}"
                f"\nbest cancer dice {best_c_metric:.4f} at epoch {best_c_epoch}"
                f"\nbest organ dice {best_o_metric:.4f} at epoch {best_o_epoch}"
            )

        # Every epoch is checkpointed with the training losses in the filename.
        torch.save(
            checkpoint,
            os.path.join(
                args.output_dir,
                f"epoch{epoch + 1}-dc_l{epoch_dc_loss:.4f}-ce_l{epoch_ce_loss:.4f}.pth"
            )
        )

        # Brief pause between epochs — presumably to let logging/I/O settle;
        # TODO(review): confirm this throttle is still needed.
        time.sleep(0.003)
    return
-
-
def train_step(model, optimizer, batch_data, epoch, step, loss_function, loss_weights, demo_dir, device, args):
    """Run a single AMP optimization step.

    Returns:
        (total_loss, loss_terms) where loss_terms is the [dice_loss, ce_loss]
        list produced by losses(), and total_loss is their weighted sum using
        loss_weights. Every args.demo_interval-th step also writes a demo plot.
    """
    images = batch_data["image"].to(device)
    targets = batch_data["label"].to(device)

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        predictions = model(images)
        loss_terms = losses(predictions, targets, epoch, loss_function, device, args)

    # Weighted sum of the individual loss terms (starts from 0, like the original loop).
    total_loss = sum(term * weight for term, weight in zip(loss_terms, loss_weights))

    scaler.scale(total_loss).backward()
    scaler.step(optimizer)
    scaler.update()

    if (step - 1) % args.demo_interval == 0:
        plot_demo(images, predictions, targets, epoch, step, demo_dir, args)

    return total_loss, loss_terms
-
-
def valid_step(model, val_loader, device, args):
    """Evaluate the model on the validation set via sliding-window inference.

    Accumulates dice per batch in the module-level metric_c / metric_o, then
    aggregates and resets them.

    Returns:
        (cancer_dice, organ_dice) as plain floats.
    """
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            images = batch["image"].to(device)
            raw_labels = batch["label"].to(device)
            organ_gt, cancer_gt = new_label(raw_labels, device)

            cancer_logits, organ_logits = sliding_window_inference(
                images, args.image_size, args.sw_batch_size, model, overlap=0.5
            )

            cancer_preds = [post_pred_c(item) for item in decollate_batch(cancer_logits)]
            cancer_gts = [post_label_c(item) for item in decollate_batch(cancer_gt)]
            organ_preds = [post_pred_o(item) for item in decollate_batch(organ_logits)]
            organ_gts = [post_label_o(item) for item in decollate_batch(organ_gt)]

            metric_c(y_pred=cancer_preds, y=cancer_gts)
            metric_o(y_pred=organ_preds, y=organ_gts)

    cancer_dice = metric_c.aggregate().item()
    organ_dice = metric_o.aggregate().item()
    metric_c.reset()
    metric_o.reset()
    return cancer_dice, organ_dice
-
-
def new_label(labels, device):
    """Split a combined label map into its two supervision targets.

    Class 14 marks cancer. Returns (organ, cancer): organ keeps all labels
    except cancer voxels (zeroed), cancer is a binary {0, 1} mask.
    """
    is_cancer = labels == 14
    organ = torch.where(is_cancer, 0, labels).to(device)
    cancer = torch.where(is_cancer, 1, 0).to(device)
    return organ, cancer
-
def pred_label(pred):
    """Fuse the two-head prediction into a single label map.

    Takes pred = (cancer_logits, organ_logits); argmaxes each over the class
    dim for the first batch element, then overwrites cancer-positive voxels
    with class 14 on top of the organ map.
    """
    cancer_map = pred[0].argmax(dim=1)[0].cpu().numpy()
    organ_map = pred[1].argmax(dim=1)[0].cpu().numpy()
    return np.where(cancer_map == 1, 14, organ_map)
-
-
def losses(outputs, labels, epoch, loss_function, device, args):
    """Compute the combined two-head losses.

    outputs is (cancer_output, organ_output). Each head is scored by
    loss_function, which yields (dice_loss, ce_loss). The heads are blended
    with the args.c_w / args.o_w weights per term.

    Returns:
        [dice_loss, ce_loss] — weighted combinations across both heads.
    """
    organ_gt, cancer_gt = new_label(labels, device)

    cancer_terms = loss_function(outputs[0], cancer_gt, epoch, num_classes=2)
    organ_terms = loss_function(outputs[1], organ_gt, epoch, num_classes=14)

    # Index 0 = dice loss, index 1 = cross-entropy loss.
    return [
        args.c_w * cancer_terms[i] + args.o_w * organ_terms[i]
        for i in range(2)
    ]
-
-
def plot_demo(image, pred, label, epoch, step, demo_dir, args):
    """Save an overlay figure comparing ground truth and prediction.

    Stacks the first sample's image twice (top: label overlay, bottom:
    prediction overlay) and writes the figure to demo_dir as a .jpg.
    """
    img_slice = image[0, 0].cpu().numpy()
    pred_slice = pred_label(pred)
    label_slice = label[0][0].cpu().numpy()

    background = np.concatenate([img_slice, img_slice], axis=0)
    overlay = np.concatenate([label_slice, pred_slice], axis=0)

    plt.imshow(background, cmap='gray')
    plt.imshow(overlay, cmap='hot', alpha=0.5)
    fig = plt.gcf()  # grab the current figure
    fig.savefig(f'{demo_dir}/train-epoch{epoch+1}_step{step}.jpg')
    fig.clear()  # free the figure's contents between saves
    return
|