You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

222 lines
8.8 KiB

  1. """A multi-thread tool to crop large images to sub-images for faster IO.
  2. (This preprocessing code is copied and modified from official implement:
  3. https://github.com/open-mmlab/mmsr/tree/master/codes/data_scripts)"""
  4. import os
  5. import os.path as osp
  6. import sys
  7. from multiprocessing import Pool
  8. import numpy as np
  9. import cv2
  10. from PIL import Image
  11. import time
  12. from shutil import get_terminal_size
  13. sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
  14. def main():
  15. mode = 'pair' # single (one input folder) | pair (extract corresponding GT and LR pairs)
  16. opt = {}
  17. opt['n_thread'] = 20
  18. opt['compression_level'] = 3 # 3 is the default value in cv2
  19. # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
  20. # compression time. If read raw images during training, use 0 for faster IO speed.
  21. if mode == 'single':
  22. opt['input_folder'] = './data/DIV2K/DIV2K_train_HR'
  23. opt['save_folder'] = './data/DIV2K/DIV2K800_sub'
  24. opt['crop_sz'] = 128 # the size of each sub-image
  25. opt['step'] = 64 # step of the sliding crop window
  26. opt['thres_sz'] = 12 # size threshold
  27. extract_signle(opt)
  28. elif mode == 'pair':
  29. GT_folder = './data/DIV2K/DIV2K_train_HR'
  30. LR_folder = './data/DIV2K/DIV2K_train_LR_X4'
  31. save_GT_folder = './data/DIV2K/DIV2K800_sub'
  32. save_LR_folder = './data/DIV2K/DIV2K800_sub_LRx4'
  33. scale_ratio = 4
  34. crop_sz = 480 # the size of each sub-image (GT)
  35. step = 240 # step of the sliding crop window (GT)
  36. thres_sz = 48 # size threshold
  37. ########################################################################
  38. # check that all the GT and LR images have correct scale ratio
  39. img_GT_list = _get_paths_from_images(GT_folder)
  40. img_LR_list = _get_paths_from_images(LR_folder)
  41. assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.'
  42. for path_GT, path_LR in zip(img_GT_list, img_LR_list):
  43. img_GT = Image.open(path_GT)
  44. img_LR = Image.open(path_LR)
  45. w_GT, h_GT = img_GT.size
  46. w_LR, h_LR = img_LR.size
  47. assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501
  48. w_GT, scale_ratio, w_LR, path_GT)
  49. assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501
  50. w_GT, scale_ratio, w_LR, path_GT)
  51. # check crop size, step and threshold size
  52. assert crop_sz % scale_ratio == 0, 'crop size is not {:d}X multiplication.'.format(
  53. scale_ratio)
  54. assert step % scale_ratio == 0, 'step is not {:d}X multiplication.'.format(scale_ratio)
  55. assert thres_sz % scale_ratio == 0, 'thres_sz is not {:d}X multiplication.'.format(
  56. scale_ratio)
  57. print('process GT...')
  58. opt['input_folder'] = GT_folder
  59. opt['save_folder'] = save_GT_folder
  60. opt['crop_sz'] = crop_sz
  61. opt['step'] = step
  62. opt['thres_sz'] = thres_sz
  63. extract_signle(opt)
  64. print('process LR...')
  65. opt['input_folder'] = LR_folder
  66. opt['save_folder'] = save_LR_folder
  67. opt['crop_sz'] = crop_sz // scale_ratio
  68. opt['step'] = step // scale_ratio
  69. opt['thres_sz'] = thres_sz // scale_ratio
  70. extract_signle(opt)
  71. assert len(_get_paths_from_images(save_GT_folder)) == len(
  72. _get_paths_from_images(
  73. save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.'
  74. else:
  75. raise ValueError('Wrong mode.')
  76. def extract_signle(opt):
  77. input_folder = opt['input_folder']
  78. save_folder = opt['save_folder']
  79. if not osp.exists(save_folder):
  80. os.makedirs(save_folder)
  81. print('mkdir [{:s}] ...'.format(save_folder))
  82. else:
  83. print('Folder [{:s}] already exists. Exit...'.format(save_folder))
  84. sys.exit(1)
  85. img_list = _get_paths_from_images(input_folder)
  86. def update(arg):
  87. pbar.update(arg)
  88. pbar = ProgressBar(len(img_list))
  89. pool = Pool(opt['n_thread'])
  90. for path in img_list:
  91. pool.apply_async(worker, args=(path, opt), callback=update)
  92. pool.close()
  93. pool.join()
  94. print('All subprocesses done.')
  95. def worker(path, opt):
  96. crop_sz = opt['crop_sz']
  97. step = opt['step']
  98. thres_sz = opt['thres_sz']
  99. img_name = osp.basename(path)
  100. img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
  101. n_channels = len(img.shape)
  102. if n_channels == 2:
  103. h, w = img.shape
  104. elif n_channels == 3:
  105. h, w, c = img.shape
  106. else:
  107. raise ValueError('Wrong image shape - {}'.format(n_channels))
  108. h_space = np.arange(0, h - crop_sz + 1, step)
  109. if h - (h_space[-1] + crop_sz) > thres_sz:
  110. h_space = np.append(h_space, h - crop_sz)
  111. w_space = np.arange(0, w - crop_sz + 1, step)
  112. if w - (w_space[-1] + crop_sz) > thres_sz:
  113. w_space = np.append(w_space, w - crop_sz)
  114. index = 0
  115. for x in h_space:
  116. for y in w_space:
  117. index += 1
  118. if n_channels == 2:
  119. crop_img = img[x:x + crop_sz, y:y + crop_sz]
  120. else:
  121. crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
  122. crop_img = np.ascontiguousarray(crop_img)
  123. cv2.imwrite(
  124. osp.join(opt['save_folder'],
  125. img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img,
  126. [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
  127. return 'Processing {:s} ...'.format(img_name)
  128. # ##############
  129. # ### Utils ####
  130. # ##############
  131. class ProgressBar(object):
  132. '''A progress bar which can print the progress
  133. modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
  134. '''
  135. def __init__(self, task_num=0, bar_width=50, start=True):
  136. self.task_num = task_num
  137. max_bar_width = self._get_max_bar_width()
  138. self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
  139. self.completed = 0
  140. if start:
  141. self.start()
  142. def _get_max_bar_width(self):
  143. terminal_width, _ = get_terminal_size()
  144. max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
  145. if max_bar_width < 10:
  146. print('terminal width is too small ({}), please consider widen the terminal for better '
  147. 'progressbar visualization'.format(terminal_width))
  148. max_bar_width = 10
  149. return max_bar_width
  150. def start(self):
  151. if self.task_num > 0:
  152. sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
  153. ' ' * self.bar_width, self.task_num, 'Start...'))
  154. else:
  155. sys.stdout.write('completed: 0, elapsed: 0s')
  156. sys.stdout.flush()
  157. self.start_time = time.time()
  158. def update(self, msg='In progress...'):
  159. self.completed += 1
  160. elapsed = time.time() - self.start_time + 1e-9
  161. fps = self.completed / elapsed
  162. if self.task_num > 0:
  163. percentage = self.completed / float(self.task_num)
  164. eta = int(elapsed * (1 - percentage) / percentage + 0.5)
  165. mark_width = int(self.bar_width * percentage)
  166. bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
  167. sys.stdout.write('\033[2F') # cursor up 2 lines
  168. sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
  169. sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
  170. bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
  171. else:
  172. sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
  173. self.completed, int(elapsed + 0.5), fps))
  174. sys.stdout.flush()
  175. # ###################
  176. # ### Data Utils ####
  177. # ###################
  178. IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
  179. def is_image_file(filename):
  180. return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
  181. def _get_paths_from_images(path):
  182. """get image path list from image folder"""
  183. assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
  184. images = []
  185. for dirpath, _, fnames in sorted(os.walk(path)):
  186. for fname in sorted(fnames):
  187. if is_image_file(fname):
  188. img_path = os.path.join(dirpath, fname)
  189. images.append(img_path)
  190. assert images, '{:s} has no valid image file'.format(path)
  191. return images
# Run the cropping pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()