|
- import paddle
- import paddle.nn as nn
- from typing import Optional, Mapping, Union, Tuple
-
- from .base import Merger, Compose
-
-
class SegmentationTTAWrapper(nn.Layer):
    """Wrap a Paddle ``nn.Layer`` segmentation model with test-time augmentation.

    Args:
        model (paddle.nn.Layer): segmentation model with single input and single output
            (``forward(x)`` should return either a Tensor or a ``Mapping[str, Tensor]``)
        transforms (patta.Compose): composition of test-time transforms
        merge_mode (str): how to merge augmented predictions:
            mean/gmean/max/min/sum/tsharpen
        output_mask_key (str): if the model output is a ``dict``, the key that
            holds the predicted mask
    """

    def __init__(
        self,
        model: nn.Layer,
        transforms: Compose,
        merge_mode: str = "mean",
        output_mask_key: Optional[str] = None,
    ):
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_mask_key

    def forward(self, image, *args):
        # Accumulates de-augmented predictions from every transform.
        merger = Merger(type=self.merge_mode, n=len(self.transforms))

        for transform in self.transforms:
            aug_image = transform.augment_image(image)
            # NOTE(review): the model output is indexed with [0] — assumes the
            # model returns a sequence whose first element is the mask logits
            # (typical for PaddleSeg models); confirm against the model used.
            output = self.model(aug_image, *args)[0]
            if self.output_key is not None:
                output = output[self.output_key]
            merger.append(transform.deaugment_mask(output))

        result = merger.result
        # Mirror the input convention: wrap back into a dict when a key was given.
        return result if self.output_key is None else {self.output_key: result}
-
-
class ClassificationTTAWrapper(nn.Layer):
    """Wrap a Paddle ``nn.Layer`` classification model with test-time augmentation.

    Args:
        model (paddle.nn.Layer): classification model with single input and single output
            (``forward(x)`` should return either a Tensor or a ``Mapping[str, Tensor]``)
        transforms (patta.Compose): composition of test-time transforms
        merge_mode (str): how to merge augmented predictions:
            mean/gmean/max/min/sum/tsharpen
        output_label_key (str): if the model output is a ``dict``, the key that
            holds the predicted label
    """

    def __init__(
        self,
        model: nn.Layer,
        transforms: Compose,
        merge_mode: str = "mean",
        output_label_key: Optional[str] = None,
    ):
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_label_key

    def forward(self, image, *args):
        # Accumulates de-augmented predictions from every transform.
        merger = Merger(type=self.merge_mode, n=len(self.transforms))

        for transform in self.transforms:
            aug_image = transform.augment_image(image)
            output = self.model(aug_image, *args)
            if self.output_key is not None:
                output = output[self.output_key]
            merger.append(transform.deaugment_label(output))

        result = merger.result
        # Mirror the input convention: wrap back into a dict when a key was given.
        return result if self.output_key is None else {self.output_key: result}
-
-
class KeypointsTTAWrapper(nn.Layer):
    """Wrap a Paddle ``nn.Layer`` keypoints model with test-time augmentation.

    Args:
        model (paddle.nn.Layer): keypoints model with single input and single output
            in format [x1, y1, x2, y2, ..., xn, yn]
            (``forward(x)`` should return either a Tensor or a ``Mapping[str, Tensor]``)
        transforms (patta.Compose): composition of test-time transforms
        merge_mode (str): how to merge augmented predictions:
            mean/gmean/max/min/sum/tsharpen
        output_keypoints_key (str): if the model output is a ``dict``, the key that
            holds the keypoints
        scaled (bool): True if the model returns x, y already scaled to [0, 1],
            else False
    """

    def __init__(
        self,
        model: nn.Layer,
        transforms: Compose,
        merge_mode: str = "mean",
        output_keypoints_key: Optional[str] = None,
        scaled: bool = False,
    ):
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_keypoints_key
        self.scaled = scaled

    def forward(self, image, *args):
        merger = Merger(type=self.merge_mode, n=len(self.transforms))
        # BUG FIX: paddle.Tensor has no .size() method (torch-ism; in Paddle
        # `.size` is an int property, so `.size()` raises TypeError).
        # Read dimensions from `.shape` instead. Assumes NCHW input — TODO confirm.
        shape = image.shape
        batch_size, image_height, image_width = shape[0], shape[2], shape[3]

        for transformer in self.transforms:
            augmented_image = transformer.augment_image(image)
            augmented_output = self.model(augmented_image, *args)

            if self.output_key is not None:
                augmented_output = augmented_output[self.output_key]

            # Paddle's Tensor.reshape takes the target shape as a list/tuple,
            # not as positional ints (torch convention): one (x, y) pair per row.
            augmented_output = augmented_output.reshape([batch_size, -1, 2])
            if not self.scaled:
                # Normalize to [0, 1] so de-augmentation is resolution-independent.
                augmented_output[..., 0] /= image_width
                augmented_output[..., 1] /= image_height

            deaugmented_output = transformer.deaugment_keypoints(augmented_output)
            merger.append(deaugmented_output)

        result = merger.result

        if not self.scaled:
            # Undo the normalization applied before merging.
            result[..., 0] *= image_width
            result[..., 1] *= image_height
        # Flatten back to the model's [x1, y1, ..., xn, yn] layout.
        result = result.reshape([batch_size, -1])

        if self.output_key is not None:
            result = {self.output_key: result}

        return result
|