|
- #!/usr/bin/env python
- # coding: utf-8
- #
- # Author: Kazuto Nakashima
- # URL: http://kazuto1011.github.io
- # Created: 2017-10-30
-
- import random
-
- import cv2
- import numpy as np
- import torch
- from PIL import Image
- from torch.utils import data
-
-
- class _BaseDataset(data.Dataset):
- """
- Base dataset class
- """
-
- def __init__(
- self,
- root,
- split,
- ignore_label,
- mean_bgr,
- augment=True,
- base_size=None,
- crop_size=321,
- scales=(1.0),
- flip=True,
- ):
- self.root = root
- self.split = split
- self.ignore_label = ignore_label
- self.mean_bgr = np.array(mean_bgr)
- self.augment = augment
- self.base_size = base_size
- self.crop_size = crop_size
- self.scales = scales
- self.flip = flip
- self.files = []
- self._set_files()
-
- cv2.setNumThreads(0)
-
- def _set_files(self):
- """
- Create a file path/image id list.
- """
- raise NotImplementedError()
-
- def _load_data(self, image_id):
- """
- Load the image and label in numpy.ndarray
- """
- raise NotImplementedError()
-
- def _augmentation(self, image, label):
- # Scaling
- h, w = label.shape
- if self.base_size:
- if h > w:
- h, w = (self.base_size, int(self.base_size * w / h))
- else:
- h, w = (int(self.base_size * h / w), self.base_size)
- scale_factor = random.choice(self.scales)
- h, w = (int(h * scale_factor), int(w * scale_factor))
- image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR)
- label = Image.fromarray(label).resize((w, h), resample=Image.NEAREST)
- label = np.asarray(label, dtype=np.int64)
-
- # Padding to fit for crop_size
- h, w = label.shape
- pad_h = max(self.crop_size - h, 0)
- pad_w = max(self.crop_size - w, 0)
- pad_kwargs = {
- "top": 0,
- "bottom": pad_h,
- "left": 0,
- "right": pad_w,
- "borderType": cv2.BORDER_CONSTANT,
- }
- if pad_h > 0 or pad_w > 0:
- image = cv2.copyMakeBorder(image, value=self.mean_bgr, **pad_kwargs)
- label = cv2.copyMakeBorder(label, value=self.ignore_label, **pad_kwargs)
-
- # Cropping
- h, w = label.shape
- start_h = random.randint(0, h - self.crop_size)
- start_w = random.randint(0, w - self.crop_size)
- end_h = start_h + self.crop_size
- end_w = start_w + self.crop_size
- image = image[start_h:end_h, start_w:end_w]
- label = label[start_h:end_h, start_w:end_w]
-
- if self.flip:
- # Random flipping
- if random.random() < 0.5:
- image = np.fliplr(image).copy() # HWC
- label = np.fliplr(label).copy() # HW
- return image, label
-
- def __getitem__(self, index):
- image_id, image, label = self._load_data(index)
- if self.augment:
- image, label = self._augmentation(image, label)
- # Mean subtraction
- image -= self.mean_bgr
- # HWC -> CHW
- image = image.transpose(2, 0, 1)
- return image_id, image.astype(np.float32), label.astype(np.int64)
-
- def __len__(self):
- return len(self.files)
-
- def __repr__(self):
- fmt_str = "Dataset: " + self.__class__.__name__ + "\n"
- fmt_str += " # data: {}\n".format(self.__len__())
- fmt_str += " Split: {}\n".format(self.split)
- fmt_str += " Root: {}".format(self.root)
- return fmt_str
|