import functools
import warnings
from collections import abc
from inspect import getfullargspec

import numpy as np
import torch
import torch.nn as nn

from .dist_utils import allreduce_grads as _allreduce_grads


def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensors in inputs from src_type to dst_type.

    Args:
        inputs: Inputs to be cast.
        src_type (torch.dtype): Source type.
        dst_type (torch.dtype): Destination type.

    Returns:
        The same type as inputs, but all contained Tensors have been cast.
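
    Example:

        A minimal illustration; note that a bare Tensor is cast to
        ``dst_type`` without checking ``src_type``:

        >>> import torch
        >>> inputs = {'x': torch.zeros(1, dtype=torch.float), 'meta': 'str'}
        >>> outputs = cast_tensor_type(inputs, torch.float, torch.half)
        >>> outputs['x'].dtype
        torch.float16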
- """
- if isinstance(inputs, torch.Tensor):
- return inputs.to(dst_type)
- elif isinstance(inputs, str):
- return inputs
- elif isinstance(inputs, np.ndarray):
- return inputs
- elif isinstance(inputs, abc.Mapping):
- return type(inputs)({
- k: cast_tensor_type(v, src_type, dst_type)
- for k, v in inputs.items()
- })
- elif isinstance(inputs, abc.Iterable):
- return type(inputs)(
- cast_tensor_type(item, src_type, dst_type) for item in inputs)
- else:
- return inputs
-
-
- def auto_fp16(apply_to=None, out_fp32=False):
- """Decorator to enable fp16 training automatically.
-
- This decorator is useful when you write custom modules and want to support
- mixed precision training. If inputs arguments are fp32 tensors, they will
- be converted to fp16 automatically. Arguments other than fp32 tensors are
- ignored.
-
- Args:
- apply_to (Iterable, optional): The argument names to be converted.
- `None` indicates all arguments.
- out_fp32 (bool): Whether to convert the output back to fp32.
-
- Example:
-
- >>> import torch.nn as nn
- >>> class MyModule1(nn.Module):
- >>>
- >>> # Convert x and y to fp16
- >>> @auto_fp16()
- >>> def forward(self, x, y):
- >>> pass
-
- >>> import torch.nn as nn
- >>> class MyModule2(nn.Module):
- >>>
- >>> # convert pred to fp16
- >>> @auto_fp16(apply_to=('pred', ))
- >>> def do_something(self, pred, others):
- >>> pass
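
        The decorator is a no-op until the module's ``fp16_enabled``
        attribute is set to ``True`` (normally done by ``wrap_fp16_model``).
        A minimal sketch, reusing the illustrative ``MyModule1`` above:

        >>> model = MyModule1()
        >>> model.fp16_enabled = True  # usually set via wrap_fp16_model
        >>> # fp32 tensors passed to model.forward are now cast to fp16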
- """
-
- def auto_fp16_wrapper(old_func):
-
- @functools.wraps(old_func)
- def new_func(*args, **kwargs):
- # check if the module has set the attribute `fp16_enabled`, if not,
- # just fallback to the original method.
- if not isinstance(args[0], torch.nn.Module):
- raise TypeError('@auto_fp16 can only be used to decorate the '
- 'method of nn.Module')
- if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
- return old_func(*args, **kwargs)
- # get the arg spec of the decorated method
- args_info = getfullargspec(old_func)
- # get the argument names to be casted
- args_to_cast = args_info.args if apply_to is None else apply_to
- # convert the args that need to be processed
- new_args = []
- # NOTE: default args are not taken into consideration
- if args:
- arg_names = args_info.args[:len(args)]
- for i, arg_name in enumerate(arg_names):
- if arg_name in args_to_cast:
- new_args.append(
- cast_tensor_type(args[i], torch.float, torch.half))
- else:
- new_args.append(args[i])
- # convert the kwargs that need to be processed
- new_kwargs = {}
- if kwargs:
- for arg_name, arg_value in kwargs.items():
- if arg_name in args_to_cast:
- new_kwargs[arg_name] = cast_tensor_type(
- arg_value, torch.float, torch.half)
- else:
- new_kwargs[arg_name] = arg_value
- # apply converted arguments to the decorated method
- output = old_func(*new_args, **new_kwargs)
- # cast the results back to fp32 if necessary
- if out_fp32:
- output = cast_tensor_type(output, torch.half, torch.float)
- return output
-
- return new_func
-
- return auto_fp16_wrapper
-
-
- def force_fp32(apply_to=None, out_fp16=False):
- """Decorator to convert input arguments to fp32 in force.
-
- This decorator is useful when you write custom modules and want to support
- mixed precision training. If there are some inputs that must be processed
- in fp32 mode, then this decorator can handle it. If inputs arguments are
- fp16 tensors, they will be converted to fp32 automatically. Arguments other
- than fp16 tensors are ignored.
-
- Args:
- apply_to (Iterable, optional): The argument names to be converted.
- `None` indicates all arguments.
- out_fp16 (bool): Whether to convert the output back to fp16.
-
- Example:
-
- >>> import torch.nn as nn
- >>> class MyModule1(nn.Module):
- >>>
- >>> # Convert x and y to fp32
- >>> @force_fp32()
- >>> def loss(self, x, y):
- >>> pass
-
- >>> import torch.nn as nn
- >>> class MyModule2(nn.Module):
- >>>
- >>> # convert pred to fp32
- >>> @force_fp32(apply_to=('pred', ))
- >>> def post_process(self, pred, others):
- >>> pass
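
        As with ``auto_fp16``, the conversion only happens when the module's
        ``fp16_enabled`` attribute is ``True``; otherwise the decorated
        method runs unchanged. A minimal sketch, reusing the illustrative
        ``MyModule1`` above:

        >>> model = MyModule1()
        >>> model.fp16_enabled = True  # usually set via wrap_fp16_model
        >>> # fp16 tensors passed to model.loss are now cast back to fp32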
- """
-
- def force_fp32_wrapper(old_func):
-
- @functools.wraps(old_func)
- def new_func(*args, **kwargs):
- # check if the module has set the attribute `fp16_enabled`, if not,
- # just fallback to the original method.
- if not isinstance(args[0], torch.nn.Module):
- raise TypeError('@force_fp32 can only be used to decorate the '
- 'method of nn.Module')
- if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
- return old_func(*args, **kwargs)
- # get the arg spec of the decorated method
- args_info = getfullargspec(old_func)
- # get the argument names to be casted
- args_to_cast = args_info.args if apply_to is None else apply_to
- # convert the args that need to be processed
- new_args = []
- if args:
- arg_names = args_info.args[:len(args)]
- for i, arg_name in enumerate(arg_names):
- if arg_name in args_to_cast:
- new_args.append(
- cast_tensor_type(args[i], torch.half, torch.float))
- else:
- new_args.append(args[i])
- # convert the kwargs that need to be processed
- new_kwargs = dict()
- if kwargs:
- for arg_name, arg_value in kwargs.items():
- if arg_name in args_to_cast:
- new_kwargs[arg_name] = cast_tensor_type(
- arg_value, torch.half, torch.float)
- else:
- new_kwargs[arg_name] = arg_value
- # apply converted arguments to the decorated method
- output = old_func(*new_args, **new_kwargs)
- # cast the results back to fp32 if necessary
- if out_fp16:
- output = cast_tensor_type(output, torch.float, torch.half)
- return output
-
- return new_func
-
- return force_fp32_wrapper
-
-
- def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
- warnings.warning(
- '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
- 'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads')
- _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
-
-
- def wrap_fp16_model(model):
- """Wrap the FP32 model to FP16.
-
- 1. Convert FP32 model to FP16.
- 2. Remain some necessary layers to be FP32, e.g., normalization layers.
-
- Args:
- model (nn.Module): Model in FP32.
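
    Example:

        An illustrative call on a plain module:

        >>> import torch.nn as nn
        >>> model = nn.Sequential(nn.Conv2d(3, 4, 1), nn.BatchNorm2d(4))
        >>> wrap_fp16_model(model)
        >>> model[0].weight.dtype  # conv weights become fp16
        torch.float16
        >>> model[1].weight.dtype  # norm layers are patched back to fp32
        torch.float32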
- """
- # convert model to fp16
- model.half()
- # patch the normalization layers to make it work in fp32 mode
- patch_norm_fp32(model)
- # set `fp16_enabled` flag
- for m in model.modules():
- if hasattr(m, 'fp16_enabled'):
- m.fp16_enabled = True
-
-
- def patch_norm_fp32(module):
- """Recursively convert normalization layers from FP16 to FP32.
-
- Args:
- module (nn.Module): The modules to be converted in FP16.
-
- Returns:
- nn.Module: The converted module, the normalization layers have been
- converted to FP32.
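
    Example:

        A minimal illustration on an fp16 module:

        >>> import torch.nn as nn
        >>> net = nn.Sequential(nn.Conv2d(3, 4, 1), nn.BatchNorm2d(4)).half()
        >>> net = patch_norm_fp32(net)
        >>> net[1].weight.dtype  # only the norm layer is converted back
        torch.float32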
- """
- if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
- module.float()
- if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3':
- module.forward = patch_forward_method(module.forward, torch.half,
- torch.float)
- for child in module.children():
- patch_norm_fp32(child)
- return module
-
-
- def patch_forward_method(func, src_type, dst_type, convert_output=True):
- """Patch the forward method of a module.
-
- Args:
- func (callable): The original forward method.
- src_type (torch.dtype): Type of input arguments to be converted from.
- dst_type (torch.dtype): Type of input arguments to be converted to.
- convert_output (bool): Whether to convert the output back to src_type.
-
- Returns:
- callable: The patched forward method.
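
    Example:

        An illustrative patch that lets an fp32 layer accept fp16 inputs:

        >>> import torch
        >>> import torch.nn as nn
        >>> norm = nn.BatchNorm1d(4)  # parameters stay fp32
        >>> norm.forward = patch_forward_method(norm.forward, torch.half,
        >>>                                     torch.float)
        >>> norm(torch.randn(2, 4).half()).dtype
        torch.float16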
- """
-
- def new_forward(*args, **kwargs):
- output = func(*cast_tensor_type(args, src_type, dst_type),
- **cast_tensor_type(kwargs, src_type, dst_type))
- if convert_output:
- output = cast_tensor_type(output, dst_type, src_type)
- return output
-
- return new_forward
-
-
- class LossScaler:
- """Class that manages loss scaling in mixed precision training which
- supports both dynamic or static mode.
-
- The implementation refers to
- https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py.
- Indirectly, by supplying ``mode='dynamic'`` for dynamic loss scaling.
- It's important to understand how :class:`LossScaler` operates.
- Loss scaling is designed to combat the problem of underflowing
- gradients encountered at long times when training fp16 networks.
- Dynamic loss scaling begins by attempting a very high loss
- scale. Ironically, this may result in OVERflowing gradients.
- If overflowing gradients are encountered, :class:`FP16_Optimizer` then
- skips the update step for this particular iteration/minibatch,
- and :class:`LossScaler` adjusts the loss scale to a lower value.
- If a certain number of iterations occur without overflowing gradients
- detected,:class:`LossScaler` increases the loss scale once more.
- In this way :class:`LossScaler` attempts to "ride the edge" of always
- using the highest loss scale possible without incurring overflow.
-
- Args:
- init_scale (float): Initial loss scale value, default: 2**32.
- scale_factor (float): Factor used when adjusting the loss scale.
- Default: 2.
- mode (str): Loss scaling mode. 'dynamic' or 'static'
- scale_window (int): Number of consecutive iterations without an
- overflow to wait before increasing the loss scale. Default: 1000.
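
    Example:

        A usage sketch of one training iteration; ``model``, ``optimizer``
        and ``loss`` are assumed to be defined elsewhere:

        >>> scaler = LossScaler(mode='dynamic')
        >>> optimizer.zero_grad()
        >>> (loss * scaler.loss_scale).backward()
        >>> params = [p for p in model.parameters() if p.grad is not None]
        >>> overflow = scaler.has_overflow(params)
        >>> if not overflow:
        >>>     for p in params:
        >>>         p.grad.data.div_(scaler.loss_scale)  # unscale grads
        >>>     optimizer.step()
        >>> scaler.update_scale(overflow)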
- """
-
- def __init__(self,
- init_scale=2**32,
- mode='dynamic',
- scale_factor=2.,
- scale_window=1000):
- self.cur_scale = init_scale
- self.cur_iter = 0
- assert mode in ('dynamic',
- 'static'), 'mode can only be dynamic or static'
- self.mode = mode
- self.last_overflow_iter = -1
- self.scale_factor = scale_factor
- self.scale_window = scale_window
-
- def has_overflow(self, params):
- """Check if params contain overflow."""
- if self.mode != 'dynamic':
- return False
- for p in params:
- if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data):
- return True
- return False
-
- def _has_inf_or_nan(x):
- """Check if params contain NaN."""
- try:
- cpu_sum = float(x.float().sum())
- except RuntimeError as instance:
- if 'value cannot be converted' not in instance.args[0]:
- raise
- return True
- else:
- if cpu_sum == float('inf') or cpu_sum == -float('inf') \
- or cpu_sum != cpu_sum:
- return True
- return False
-
- def update_scale(self, overflow):
- """update the current loss scale value when overflow happens."""
- if self.mode != 'dynamic':
- return
- if overflow:
- self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
- self.last_overflow_iter = self.cur_iter
- else:
- if (self.cur_iter - self.last_overflow_iter) % \
- self.scale_window == 0:
- self.cur_scale *= self.scale_factor
- self.cur_iter += 1
-
- @property
- def loss_scale(self):
- return self.cur_scale