|
- import torch.utils.data as data
- import os
- from PIL import Image
- import numpy as np
- import scipy.io as scio
- import torch
-
- # from skimage import morphology, io
-
- IMG_EXTENSIONS = [
- '.jpg', '.JPG', '.jpeg', '.JPEG',
- '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
- ]
-
-
- def is_image_file(filename):
- return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
-
-
- def img_loader(path, num_channels):
- if num_channels == 1:
- if ('.mat' in path):
- img = scio.loadmat(path)['inst_map']
- img = Image.fromarray(img.astype(np.uint8))
- elif ('.npy' in path):
- img = np.load(path)
- img = Image.fromarray(img.astype(np.uint8))
- else:
- img = Image.open(path)
- else:
- if ('.mat' in path):
- img = scio.loadmat(path)['inst_map']
- elif ('.npy' in path):
- img = np.load(path)
- img = Image.fromarray(img.astype(np.uint8))
- else:
- img = Image.open(path).convert('RGB')
-
- return img
-
-
- # get the image list pairs
- def get_imgs_list(dir_list, post_fix=None):
- """
- :param dir_list: [img1_dir, img2_dir, ...]
- :param post_fix: e.g. ['label.png', 'weight.png',...]
- :return: e.g. [(img1.ext, img1_label.png, img1_weight.png), ...]
- """
- img_list = []
- if len(dir_list) == 0:
- return img_list
- if len(dir_list) != len(post_fix) + 1:
- raise (RuntimeError('Should specify the postfix of each img type except the first input.'))
-
- img_filename_list = [os.listdir(dir_list[i]) for i in range(len(dir_list))]
-
- for img in img_filename_list[0]:
- if not is_image_file(img):
- continue
- img1_name = os.path.splitext(img)[0]
- item = [os.path.join(dir_list[0], img), ]
- for i in range(1, len(img_filename_list)):
- img_name = '{:s}_{:s}'.format(img1_name, post_fix[i - 1])
- if img_name in img_filename_list[i]:
- img_path = os.path.join(dir_list[i], img_name)
- item.append(img_path)
-
- if len(item) == len(dir_list):
- img_list.append(tuple(item))
-
- return img_list
-
-
- # dataset that supports one input image, one target image, and one weight map (optional)
- class DataFolder(data.Dataset):
- def __init__(self, dir_list, post_fix, num_channels, data_transform=None, loader=img_loader):
- super(DataFolder, self).__init__()
- if len(dir_list) != len(post_fix) + 1:
- raise (RuntimeError('Length of dir_list is different from length of post_fix + 1.'))
- if len(dir_list) != len(num_channels):
- raise (RuntimeError('Length of dir_list is different from length of num_channels.'))
-
- self.img_list = get_imgs_list(dir_list, post_fix)
- if len(self.img_list) == 0:
- raise (RuntimeError('Found 0 image pairs in given directories.'))
-
- self.data_transform = data_transform
- self.num_channels = num_channels
- self.loader = loader
-
- def __getitem__(self, index):
- img_paths = self.img_list[index]
-
- sample = [self.loader(img_paths[i], self.num_channels[i]) for i in range(len(img_paths))]
-
- if self.data_transform is not None:
- sample_tensor = self.data_transform(sample)
-
- while (len(torch.unique(sample_tensor[2])) <= 1): # sample[2].detach().cpu().numpy()
- if self.data_transform is not None:
- sample_tensor = self.data_transform(sample)
-
- return sample_tensor
-
- def __len__(self):
- return len(self.img_list)
|