|
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import argparse
import logging
import os
import sys
from glob import glob

import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import create_test_image_2d, decollate_batch, list_data_collate
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    AddChanneld,
    AsDiscrete,
    Compose,
    EnsureType,
    EnsureTyped,
    LoadImaged,
    RandGaussianNoised,
    RandRotate90d,
    ScaleIntensityd,
)
-
-
def main(args):
    """Train a 2D UNet on (image, mask) PNG pairs and validate with mean Dice.

    Expects ``args.data_dir`` to contain ``images/*.png`` and ``masks/*.png``
    whose sorted file names correspond one-to-one. Checkpoints and TensorBoard
    logs are written to ``args.save_ckpt_path``.

    Args:
        args: parsed namespace with at least ``data_dir``, ``batch_size``,
            ``lr``, ``weight_decay``, ``epochs``, ``number_class`` and
            ``save_ckpt_path``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # Pair images with masks; pairing relies on identical sorted ordering of
    # the two directories.
    val_ratio = 0.2
    images = sorted(glob(os.path.join(args.data_dir, "images", "*.png")))
    segs = sorted(glob(os.path.join(args.data_dir, "masks", "*.png")))
    num_val = int(val_ratio * len(images))
    # BUG FIX: the original built train_files from images[:num_val], i.e. it
    # trained on only the first 20% of the data (and ignored 60% entirely).
    # Train on everything except the trailing validation split.
    split = len(images) - num_val
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:split], segs[:split])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[split:], segs[split:])]

    # Training transforms: intensity scaling plus light augmentation.
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys=["img", "seg"]),
            RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 1]),
            # BUG FIX: noise was previously applied to the mask as well,
            # which can flip label values after the later .long() cast;
            # intensity noise belongs on the image only.
            RandGaussianNoised(keys=["img"], prob=0.5, std=0.02),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys=["img", "seg"]),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # Sanity-check one batch so shape problems surface before training starts.
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate)
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # Training data loader.
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # Validation data loader (no shuffling, fixed batch size).
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=8, num_workers=4, collate_fn=list_data_collate)

    dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
    # Post-processing: argmax the logits, then one-hot both prediction and
    # label so DiceMetric receives matching one-hot tensors.
    post_trans = Compose(
        [EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=args.number_class)]
    )
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=args.number_class)])

    # Create UNet, CrossEntropy loss and Adam optimizer.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=args.number_class,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Typical PyTorch training loop with periodic validation.
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter(log_dir=args.save_ckpt_path)
    # Loop-invariant; the original recomputed this every step.
    epoch_len = len(train_ds) // train_loader.batch_size
    for epoch in range(args.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{args.epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            # CrossEntropyLoss wants class-index targets: drop the channel dim.
            loss = loss_function(outputs, labels.squeeze(dim=1).long())
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                roi_size = (448, 448)
                sw_batch_size = 4
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    # BUG FIX: the original ran a plain model(val_images)
                    # forward pass here and immediately discarded the result;
                    # sliding-window inference is the only pass needed.
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    # Accumulate the metric for the current iteration.
                    dice_metric(y_pred=val_outputs, y=val_labels)
                # Aggregate the final mean dice and reset for the next round.
                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(
                        model.state_dict(),
                        "{}/best_metric_model_segmentation2d_dict.pth".format(args.save_ckpt_path),
                    )
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
-
-
def get_args(argv=None):
    """Parse command-line options for 2D UNet training.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, preserving the original behavior;
            passing a list makes the function testable without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Train Unet on images and target mask",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # NOTE(review): --data_path is accepted but main() reads args.data_dir,
    # which the __main__ block hard-codes — confirm which one is intended.
    parser.add_argument('--data_path', type=str, default="", help='data directory')
    parser.add_argument('--batch_size', type=int, default=16, help='train batch size, default: 16')
    parser.add_argument('--lr', type=float, default=0.0001, help='learning rate, default: 0.0001')
    # BUG FIX: the help texts below previously contradicted the actual
    # defaults ("default: 100" for epochs, "default: 6" for number_class).
    parser.add_argument('--epochs', type=int, default=4000, help='epoch number, default: 4000')
    parser.add_argument('--number_class', type=int, default=2, help='classification number, default: 2')
    parser.add_argument('--save_ckpt_path', type=str, default='./ckpt', help='save checkpoint path')
    parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay')

    return parser.parse_args(argv)
-
-
- if __name__ == "__main__":
- # Default
- args = get_args()
- args.data_dir = '/data5/jiaxin/data/2d_lung_seg'
-
- print(args)
- main(args)
|