# Copyright (c) Open-MMLab. All rights reserved.
import logging
import os.path as osp
import warnings
from abc import ABCMeta, abstractmethod

import torch
from torch.optim import Optimizer

import mmcv
from ..parallel import is_module_wrapper
from .checkpoint import load_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook, IterTimerHook
from .log_buffer import LogBuffer
from .priority import get_priority
from .utils import get_time_str


class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that processes a data
            batch. The interface of this method should be
            ``batch_processor(model, data, train_mode) -> dict``. Deprecated;
            implement ``train_step()`` and ``val_step()`` in the model
            instead.
        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
            optimizer (in most cases) or a dict of optimizers (for models
            that require more than one optimizer, e.g., GANs).
        work_dir (str, optional): The working directory to save checkpoints
            and logs. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training.
            Defaults to None. (The default value is just for backward
            compatibility.)
        meta (dict | None): A dict that records some important information,
            such as environment info and seed, which will be logged by the
            logger hook. Defaults to None.
        max_epochs (int, optional): Total training epochs.
        max_iters (int, optional): Total training iterations.
    """

    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            warnings.warn('batch_processor is deprecated, please implement '
                          'train_step() and val_step() in the model instead.')
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        else:
            assert hasattr(model, 'train_step')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                'Only one of `max_epochs` or `max_iters` can be set.')

        self._max_epochs = max_epochs
        self._max_iters = max_iters
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()

    @property
    def model_name(self):
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: A list of registered hooks."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Maximum training iterations."""
        return self._max_iters

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def val(self):
        pass

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        pass
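
    # A minimal sketch of how the abstract API is typically driven (not from
    # the original file; the loader names and workflow below are assumptions):
    # concrete runners implement ``train()``/``val()`` and ``run()`` iterates
    # over the data loaders according to the workflow, e.g.
    #   runner.run(
    #       data_loaders=[train_loader, val_loader],
    #       workflow=[('train', 1), ('val', 1)])
    # meaning "one training epoch, then one validation epoch, repeated".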

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.
        """
        if isinstance(self.optimizer, torch.optim.Optimizer):
            lr = [group['lr'] for group in self.optimizer.param_groups]
        elif isinstance(self.optimizer, dict):
            lr = dict()
            for name, optim in self.optimizer.items():
                lr[name] = [group['lr'] for group in optim.param_groups]
        else:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        return lr
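
    # Illustrative sketch (the values are assumptions): with a single
    # optimizer whose two param groups use lrs 0.01 and 0.001,
    #   runner.current_lr()  ->  [0.01, 0.001]
    # and with a dict of optimizers, e.g. {'gen': opt_g, 'disc': opt_d}, a
    # dict of such lists is returned instead.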

    def current_momentum(self):
        """Get current momentums.

        Returns:
            list[float] | dict[str, list[float]]: Current momentums of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.
        """

        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group.keys():
                    momentums.append(group['momentum'])
                elif 'betas' in group.keys():
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

        if self.optimizer is None:
            raise RuntimeError(
                'momentum is not applicable because optimizer does not exist.')
        elif isinstance(self.optimizer, torch.optim.Optimizer):
            momentums = _get_momentum(self.optimizer)
        elif isinstance(self.optimizer, dict):
            momentums = dict()
            for name, optim in self.optimizer.items():
                momentums[name] = _get_momentum(optim)
        return momentums

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.

        The hook will be inserted into a priority queue, with the specified
        priority (see :class:`Priority` for details of priorities).
        Hooks with the same priority are triggered in the order in which
        they are registered.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)
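
    # Illustrative sketch (not from the original file): lower priority values
    # run earlier. With mmcv's default Priority values (HIGH=30, NORMAL=50),
    #   runner.register_hook(hook_a)                   # NORMAL
    #   runner.register_hook(hook_b, priority='HIGH')  # HIGH
    # ``hook_b`` ends up before ``hook_a`` in ``runner.hooks``, so its
    # callbacks fire first; hooks of equal priority keep registration order.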

    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg.

        Args:
            hook_cfg (dict): Hook config. It should at least contain the key
                'type' indicating the hook class to build; an optional
                'priority' key (default 'NORMAL') sets its priority.

        Note:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
        """
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)
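
    # Illustrative sketch (the hook type and priority here are assumptions):
    # any hook class registered in the HOOKS registry can be built from a
    # plain config dict, e.g.
    #   runner.register_hook_from_cfg(
    #       dict(type='IterTimerHook', priority='LOW'))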

    def call_hook(self, fn_name):
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def load_checkpoint(self, filename, map_location='cpu', strict=False):
        """Load a checkpoint into ``self.model`` and return the checkpoint."""
        self.logger.info('load checkpoint from %s', filename)
        return load_checkpoint(self.model, filename, map_location, strict,
                               self.logger)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume epoch, iteration and (optionally) optimizer states from a
        checkpoint file."""
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
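
    # Illustrative sketch (the checkpoint path is an assumption): resuming an
    # interrupted run restores the epoch/iter counters and, by default, the
    # optimizer state:
    #   runner.resume(osp.join(runner.work_dir, 'latest.pth'))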

    def register_lr_hook(self, lr_config):
        # allow ``lr_config=None`` for symmetry with the other register_*
        # helpers (no LrUpdaterHook is registered in that case)
        if lr_config is None:
            return
        if isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
            # If the policy is written in all lower case, e.g. 'cyclic', its
            # first letter is capitalized (-> 'Cyclic') for convenience when
            # looking up the Lr updater hook. Policies that already contain
            # capital letters, e.g. 'CosineAnnealing', are left unchanged.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook)
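
    # Illustrative sketch (the policy and step values are assumptions): a dict
    # config is turned into the matching LrUpdaterHook, e.g.
    #   runner.register_lr_hook(dict(policy='step', step=[8, 11]))
    # builds and registers a ``StepLrUpdaterHook`` ('step' -> 'Step').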

    def register_momentum_hook(self, momentum_config):
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            assert 'policy' in momentum_config
            policy_type = momentum_config.pop('policy')
            # If the policy is written in all lower case, e.g. 'cyclic', its
            # first letter is capitalized (-> 'Cyclic') for convenience when
            # looking up the momentum updater hook. Policies that already
            # contain capital letters, e.g. 'CosineAnnealing', are left
            # unchanged.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'MomentumUpdaterHook'
            momentum_config['type'] = hook_type
            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
        else:
            hook = momentum_config
        self.register_hook(hook)

    def register_optimizer_hook(self, optimizer_config):
        if optimizer_config is None:
            return
        if isinstance(optimizer_config, dict):
            optimizer_config.setdefault('type', 'OptimizerHook')
            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
        else:
            hook = optimizer_config
        self.register_hook(hook)

    def register_checkpoint_hook(self, checkpoint_config):
        if checkpoint_config is None:
            return
        if isinstance(checkpoint_config, dict):
            checkpoint_config.setdefault('type', 'CheckpointHook')
            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
        else:
            hook = checkpoint_config
        self.register_hook(hook)

    def register_logger_hooks(self, log_config):
        if log_config is None:
            return
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = mmcv.build_from_cfg(
                info, HOOKS, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority='VERY_LOW')

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None):
        """Register default hooks for training.

        Default hooks include:

        - LrUpdaterHook
        - MomentumUpdaterHook
        - OptimizerHook
        - CheckpointHook
        - IterTimerHook
        - LoggerHook(s)
        """
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_hook(IterTimerHook())
        self.register_logger_hooks(log_config)
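
    # Illustrative sketch (the config values are assumptions, not from the
    # original file): wiring up the default training hooks with typical
    # mmcv-style configs before calling ``run()``:
    #   runner.register_training_hooks(
    #       lr_config=dict(policy='step', step=[8, 11]),
    #       optimizer_config=dict(grad_clip=None),
    #       checkpoint_config=dict(interval=1),
    #       log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]))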