|
- #!/usr/bin/env python
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- """
- Detectron2 training script with a plain training loop.
-
- This script reads a given config file and runs the training or evaluation.
- It is an entry point that is able to train standard models in detectron2.
-
- In order to let one script support training of many models,
- this script contains logic that is specific to these built-in models and therefore
- may not be suitable for your own project.
- For example, your research project perhaps only needs a single "evaluator".
-
- Therefore, we recommend you to use detectron2 as a library and take
- this file as an example of how to use the library.
- You may want to write your own script with your datasets and other customizations.
-
- Compared to "train_net.py", this script supports fewer default features.
- It also has fewer abstractions, and is therefore easier to extend with custom logic.
- """
-
- import logging
- import os
- from collections import OrderedDict
- import torch
- from torch.nn.parallel import DistributedDataParallel
-
- import detectron2.utils.comm as comm
- from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
- from detectron2.config import get_cfg
- from detectron2.data import (
- MetadataCatalog,
- build_detection_test_loader,
- build_detection_train_loader,
- )
- from detectron2.engine import default_argument_parser, default_setup, launch
- from detectron2.evaluation import (
- CityscapesInstanceEvaluator,
- CityscapesSemSegEvaluator,
- COCOEvaluator,
- COCOPanopticEvaluator,
- DatasetEvaluators,
- LVISEvaluator,
- PascalVOCDetectionEvaluator,
- SemSegEvaluator,
- inference_on_dataset,
- print_csv_format,
- )
- from detectron2.modeling import build_model
- from detectron2.solver import build_lr_scheduler, build_optimizer
- from detectron2.utils.events import (
- CommonMetricPrinter,
- EventStorage,
- JSONWriter,
- TensorboardXWriter,
- )
-
# Module-wide logger; reuses the "detectron2" logger name rather than __name__
# so messages from this script are grouped with the library's own output.
logger = logging.getLogger(name="detectron2")
-
-
def _assert_single_machine_for_cityscapes():
    """
    Raise AssertionError when this process appears to run on a second machine.

    The Cityscapes evaluators only support single-machine evaluation. The
    check compares ``comm.get_rank()`` (presumably the *global* rank, since
    this file uses ``comm.get_local_rank()`` separately for the device id)
    against the number of local GPUs.
    """
    # Strict ">" is required: with N GPUs per machine, global rank N is already
    # the first process of the next machine, so ">=" would let it through.
    assert (
        torch.cuda.device_count() > comm.get_rank()
    ), "CityscapesEvaluator currently do not work with multiple machines."


def get_evaluator(cfg, dataset_name, output_folder=None):
    """
    Create evaluator(s) for a given dataset.
    This uses the special metadata "evaluator_type" associated with each builtin dataset.
    For your own dataset, you can simply create an evaluator manually in your
    script and do not have to worry about the hacky if-else logic here.

    Args:
        cfg: the config; reads ``OUTPUT_DIR`` and, for semantic segmentation,
            ``MODEL.SEM_SEG_HEAD.NUM_CLASSES`` / ``MODEL.SEM_SEG_HEAD.IGNORE_VALUE``.
        dataset_name (str): name whose ``MetadataCatalog`` entry supplies
            ``evaluator_type``.
        output_folder (str or None): output directory for evaluators that take
            one; defaults to ``os.path.join(cfg.OUTPUT_DIR, "inference")``.

    Returns:
        A single evaluator when only one applies, otherwise a
        ``DatasetEvaluators`` wrapping all applicable evaluators.

    Raises:
        NotImplementedError: if ``evaluator_type`` matches no known evaluator.
        AssertionError: for Cityscapes types when running on multiple machines.
    """
    if output_folder is None:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
    evaluator_list = []
    evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type

    # Panoptic datasets get semantic, instance (COCO) and panoptic evaluation.
    if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
        evaluator_list.append(
            SemSegEvaluator(
                dataset_name,
                distributed=True,
                num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
                ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
                output_dir=output_folder,
            )
        )
    if evaluator_type in ["coco", "coco_panoptic_seg"]:
        evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder))
    if evaluator_type == "coco_panoptic_seg":
        evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))

    # The remaining types map to exactly one evaluator and return immediately.
    if evaluator_type == "cityscapes_instance":
        _assert_single_machine_for_cityscapes()
        return CityscapesInstanceEvaluator(dataset_name)
    if evaluator_type == "cityscapes_sem_seg":
        _assert_single_machine_for_cityscapes()
        return CityscapesSemSegEvaluator(dataset_name)
    if evaluator_type == "pascal_voc":
        return PascalVOCDetectionEvaluator(dataset_name)
    if evaluator_type == "lvis":
        return LVISEvaluator(dataset_name, cfg, True, output_folder)

    if not evaluator_list:
        raise NotImplementedError(
            "no Evaluator for the dataset {} with the type {}".format(dataset_name, evaluator_type)
        )
    if len(evaluator_list) == 1:
        return evaluator_list[0]
    return DatasetEvaluators(evaluator_list)
-
-
def do_test(cfg, model):
    """
    Run inference and evaluation on every dataset in ``cfg.DATASETS.TEST``.

    Each dataset's evaluator writes under
    ``<cfg.OUTPUT_DIR>/inference/<dataset_name>``. The main process also logs
    each dataset's results in CSV form.

    Returns:
        The single dataset's result object when exactly one result was
        collected; otherwise an ``OrderedDict`` mapping dataset name to result,
        in iteration order.
    """
    all_results = OrderedDict()
    for name in cfg.DATASETS.TEST:
        loader = build_detection_test_loader(cfg, name)
        evaluator = get_evaluator(cfg, name, os.path.join(cfg.OUTPUT_DIR, "inference", name))
        dataset_results = inference_on_dataset(model, loader, evaluator)
        all_results[name] = dataset_results
        if not comm.is_main_process():
            continue
        logger.info("Evaluation results for {} in csv format:".format(name))
        print_csv_format(dataset_results)

    # Unwrap the single-dataset case so callers get the bare result.
    if len(all_results) == 1:
        return next(iter(all_results.values()))
    return all_results
-
-
def do_train(cfg, model, resume=False):
    """
    Train ``model`` with a plain, hand-written loop configured by ``cfg``.

    Builds the optimizer, LR scheduler and a checkpointer, and passes
    ``cfg.MODEL.WEIGHTS`` and ``resume`` to ``checkpointer.resume_or_load``.
    It then iterates the training data loader up to ``cfg.SOLVER.MAX_ITER``.
    Along the way it evaluates every ``cfg.TEST.EVAL_PERIOD`` iterations when
    that is positive (except at the final iteration), flushes metric writers,
    and steps the periodic checkpointer.

    Args:
        cfg: the config driving the run.
        model: the model; may be wrapped in DistributedDataParallel by the caller.
        resume (bool): forwarded to ``checkpointer.resume_or_load``.
    """
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
    )
    # With no stored "iteration" in the loaded state, -1 + 1 starts training at 0.
    loaded_state = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
    start_iter = loaded_state.get("iteration", -1) + 1
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )

    # Only the main process gets metric writers (console, metrics.json, TensorBoard).
    if comm.is_main_process():
        writers = [
            CommonMetricPrinter(max_iter),
            JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(cfg.OUTPUT_DIR),
        ]
    else:
        writers = []

    # compared to "train_net.py", we do not support accurate timing and
    # precise BN here, because they are not trivial to implement in a small training loop
    data_loader = build_detection_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        # zip stops at max_iter even if the data loader is endless.
        for batch, step_idx in zip(data_loader, range(start_iter, max_iter)):
            # ``iteration`` is 1-based: the count of iterations done after this step.
            iteration = step_idx + 1
            storage.step()

            loss_terms = model(batch)
            total_loss = sum(loss_terms.values())
            assert torch.isfinite(total_loss).all(), loss_terms

            # Cross-worker reduced copies are used only for logging; the
            # backward pass below uses this worker's own ``total_loss``.
            reduced_terms = {
                key: value.item() for key, value in comm.reduce_dict(loss_terms).items()
            }
            reduced_total = sum(reduced_terms.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=reduced_total, **reduced_terms)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
            scheduler.step()

            eval_due = (
                cfg.TEST.EVAL_PERIOD > 0
                and iteration % cfg.TEST.EVAL_PERIOD == 0
                and iteration != max_iter
            )
            if eval_due:
                do_test(cfg, model)
                # Compared to "train_net.py", the test results are not dumped to EventStorage
                comm.synchronize()

            # Skip the first few iterations, then flush every 20 and at the end.
            if iteration - start_iter > 5 and (iteration % 20 == 0 or iteration == max_iter):
                for writer in writers:
                    writer.write()
            # NOTE(review): the 1-based ``iteration`` is passed here; confirm the
            # installed PeriodicCheckpointer does not expect a 0-based index
            # (otherwise checkpoint names / resume points may be off by one).
            periodic_checkpointer.step(iteration)
-
-
def setup(args):
    """
    Create configs and perform basic setups.

    Starts from ``get_cfg()`` defaults. It merges ``args.config_file``, then
    applies the command-line overrides in ``args.opts``, freezes the config,
    and runs ``default_setup``.

    Returns:
        The frozen config.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    # if you don't like any of the default setup, write your own setup code
    default_setup(config, args)
    return config
-
-
def main(args):
    """
    Per-process entry point handed to ``launch`` in the ``__main__`` block.

    With ``args.eval_only`` it loads weights and only evaluates. Otherwise it
    trains the model (wrapped in DistributedDataParallel when the world size
    is greater than 1) and then evaluates it.

    Returns:
        Whatever ``do_test`` returns.
    """
    cfg = setup(args)

    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        return do_test(cfg, model)

    distributed = comm.get_world_size() > 1
    if distributed:
        # find_unused_parameters=True lets DDP cope with parameters that get no
        # gradient in a given iteration; buffers are not re-broadcast each forward.
        model = DistributedDataParallel(
            model,
            device_ids=[comm.get_local_rank()],
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

    do_train(cfg, model, resume=args.resume)
    return do_test(cfg, model)
-
-
if __name__ == "__main__":
    # Parse the CLI once in the launching process; ``launch`` receives ``main``
    # together with the parsed namespace as its argument tuple.
    cli_args = default_argument_parser().parse_args()
    print("Command Line Args:", cli_args)
    launch(
        main,
        cli_args.num_gpus,
        num_machines=cli_args.num_machines,
        machine_rank=cli_args.machine_rank,
        dist_url=cli_args.dist_url,
        args=(cli_args,),
    )
|