|
- import os
- import imageio
- import cv2
- import numpy as np
- from warnings import warn
- from time import sleep
- import argparse
-
- from multiprocessing import Pool
- from multiprocessing import TimeoutError as MP_TimeoutError
-
- START = "START"
- FINISH = "FINISH"
- WARNING = "WARNING"
- FAIL = "FAIL"
-
-
def boolean_string(s):
    """Convert a case-insensitive ``'true'`` / ``'false'`` string to a bool.

    Used as the ``type=`` converter of the ``--log`` command-line option.

    Args:
        s: Text to interpret.

    Returns:
        ``True`` if ``s`` upper-cases to ``'TRUE'``, ``False`` if it
        upper-cases to ``'FALSE'``.

    Raises:
        ValueError: If ``s`` is neither of the two accepted spellings.
    """
    normalized = s.upper()
    if normalized == 'TRUE':
        return True
    if normalized == 'FALSE':
        return False
    raise ValueError('Not a valid boolean string')
-
-
# Command-line interface. Arguments are parsed at import time and exposed
# as the module-level constants that the functions below read.
parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--input_path', default='', type=str,
                    help='Root path of raw dataset.')
parser.add_argument('--output_path', default='', type=str,
                    help='Root path for output.')
parser.add_argument('--log_file', default='./pretreatment.log', type=str,
                    help='Log file path. Default: ./pretreatment.log')
# NOTE: the adjacent string literals are concatenated verbatim, so each
# sentence must carry its own trailing space.
parser.add_argument('--log', default=False, type=boolean_string,
                    help='If set as True, all logs will be saved. '
                         'Otherwise, only warnings and errors will be saved. '
                         'Default: False')
parser.add_argument('--worker_num', default=1, type=int,
                    help='How many subprocesses to use for data pretreatment. '
                         'Default: 1')
opt = parser.parse_args()

INPUT_PATH = opt.input_path
OUTPUT_PATH = opt.output_path
IF_LOG = opt.log
LOG_PATH = opt.log_file
WORKERS = opt.worker_num

# Target height and width of the silhouettes produced by cut_img():
# frames are resized to T_H rows and cropped to a T_W-column window.
T_H = 64
T_W = 64
-
-
def log2str(pid, comment, logs):
    """Format one or more log messages as ``# JOB`` lines.

    Args:
        pid: Integer job index, rendered with ``%d``.
        comment: Tag placed between the ``--`` markers (for example
            ``START`` or ``WARNING``).
        logs: Either a single message string or an iterable of messages.

    Returns:
        The concatenation of one ``# JOB <pid> : --<comment>-- <message>``
        line per message, each terminated by a newline. An empty iterable
        yields an empty string.
    """
    # isinstance (rather than an exact type comparison) so that instances
    # of str subclasses count as one message instead of being iterated
    # character by character.
    if isinstance(logs, str):
        logs = [logs]
    return ''.join(
        "# JOB %d : --%s-- %s\n" % (pid, comment, log) for log in logs)
-
-
def log_print(pid, comment, logs):
    """Record a job message in the log file and/or on the console.

    Args:
        pid: Integer job index the message belongs to.
        comment: Message category; compared against the module-level
            ``START``, ``FINISH``, ``WARNING`` and ``FAIL`` tags.
        logs: A message string or an iterable of message strings, formatted
            by ``log2str``.

    Side effects:
        Appends the formatted text to ``LOG_PATH`` for warnings and
        failures, and for every category when ``IF_LOG`` is true (the
        ``--log`` option). Prints the text to stdout, except that
        ``START``/``FINISH`` messages are printed only for job indices
        divisible by 500.
    """
    str_log = log2str(pid, comment, logs)
    # Warnings and failures are always persisted; other categories only
    # when full logging was requested, as the --log help text promises.
    if IF_LOG or comment in (WARNING, FAIL):
        with open(LOG_PATH, 'a') as log_f:
            log_f.write(str_log)
    # Progress messages are echoed for every 500th job only, to keep the
    # console output short.
    if comment in (START, FINISH) and pid % 500 != 0:
        return
    print(str_log, end='')
-
-
def cut_img(img, seq_info, frame_name, pid):
    """Crop one silhouette frame vertically and centre it horizontally.

    Args:
        img: 2-D image array (indexed as rows x columns).
            NOTE(review): presumably a uint8 binary silhouette with white
            foreground - confirm against the dataset.
        seq_info: Sequence of strings identifying the sequence; joined with
            ``-`` in warning messages.
        frame_name: File name of the frame, used in warning messages.
        pid: Integer job index forwarded to ``log_print``.

    Returns:
        A ``uint8`` array with ``T_H`` rows and ``2 * int(T_W / 2)``
        columns, or ``None`` when the frame is rejected (pixel sum not
        above 10000, or no horizontal centre could be found). Each
        rejection emits a warning and a ``WARNING`` log entry.
    """
    seq_name = '-'.join(seq_info)

    # A silhouette with too few white pixels is probably not usable for
    # identification, so it is dropped.
    if img.sum() <= 10000:
        message = 'seq:%s, frame:%s, no data, %d.' % (
            seq_name, frame_name, img.sum())
        warn(message)
        log_print(pid, WARNING, message)
        return None

    # Locate the first and last non-empty rows and keep only that band.
    row_has_pixels = img.sum(axis=1) != 0
    top = row_has_pixels.argmax(axis=0)
    # The running count of non-empty rows first reaches its maximum at the
    # last non-empty row.
    bottom = row_has_pixels.cumsum(axis=0).argmax(axis=0)
    band = img[top:bottom + 1, :]

    # Scale so the band is exactly T_H rows tall, keeping the aspect ratio.
    aspect = band.shape[1] / band.shape[0]
    scaled_width = int(T_H * aspect)
    scaled = cv2.resize(band, (scaled_width, T_H),
                        interpolation=cv2.INTER_CUBIC)

    # The x centre is the first column at which the cumulative pixel mass
    # exceeds half of the total mass.
    total_mass = scaled.sum()
    cumulative_mass = scaled.sum(axis=0).cumsum()
    past_half = np.flatnonzero(cumulative_mass > total_mass / 2)
    if past_half.size == 0:
        message = 'seq:%s, frame:%s, no center.' % (
            seq_name, frame_name)
        warn(message)
        log_print(pid, WARNING, message)
        return None
    x_center = int(past_half[0])

    half_width = int(T_W / 2)
    left = x_center - half_width
    right = x_center + half_width
    if left <= 0 or right >= scaled.shape[1]:
        # The crop window touches or crosses an image border: pad both
        # sides with zero columns and shift the window by the pad width.
        left += half_width
        right += half_width
        padding = np.zeros((scaled.shape[0], half_width))
        scaled = np.concatenate([padding, scaled, padding], axis=1)
    return scaled[:, left:right].astype('uint8')
-
-
def cut_pickle(seq_info, pid):
    """Pre-process every frame of one sequence and save the results.

    Reads each file in ``INPUT_PATH/<seq_info...>`` in sorted name order,
    passes its first channel through ``cut_img`` and writes accepted
    frames under the same file name to ``OUTPUT_PATH/<seq_info...>``.

    Args:
        seq_info: Sequence of path components identifying the sequence
            below ``INPUT_PATH`` / ``OUTPUT_PATH``.
        pid: Integer job index used for logging.

    Side effects:
        Writes image files, and emits ``START``, ``FINISH`` and
        ``WARNING`` log entries (a warning is issued for each unreadable
        frame and when fewer than 5 frames were saved).
    """
    seq_name = '-'.join(seq_info)
    log_print(pid, START, seq_name)
    seq_path = os.path.join(INPUT_PATH, *seq_info)
    out_dir = os.path.join(OUTPUT_PATH, *seq_info)
    count_frame = 0
    for frame_name in sorted(os.listdir(seq_path)):
        frame_path = os.path.join(seq_path, frame_name)
        raw = cv2.imread(frame_path)
        # cv2.imread signals an unreadable or non-image file by returning
        # None rather than raising; skip such files instead of letting the
        # channel indexing below abort the whole sequence.
        if raw is None:
            message = 'seq:%s, frame:%s, unreadable image.' % (
                seq_name, frame_name)
            warn(message)
            log_print(pid, WARNING, message)
            continue
        img = cut_img(raw[:, :, 0], seq_info, frame_name, pid)
        if img is not None:
            # Save the cut img
            save_path = os.path.join(out_dir, frame_name)
            imageio.imwrite(save_path, img)
            count_frame += 1
    # Warn if the sequence contains less than 5 frames
    if count_frame < 5:
        message = 'seq:%s, less than 5 valid data.' % seq_name
        warn(message)
        log_print(pid, WARNING, message)

    log_print(pid, FINISH,
              'Contain %d valid frames. Saved to %s.'
              % (count_frame, out_dir))
-
-
# Driver: submit one cut_pickle job per <id>/<seq_type>/<view> directory
# and wait for all of them. The __main__ guard keeps worker processes that
# re-import this module (the "spawn" start method) from re-running it.
if __name__ == '__main__':
    pool = Pool(WORKERS)
    results = list()
    pid = 0

    print('Pretreatment Start.\n'
          'Input path: %s\n'
          'Output path: %s\n'
          'Log file: %s\n'
          'Worker num: %d' % (
              INPUT_PATH, OUTPUT_PATH, LOG_PATH, WORKERS))

    # Walk the input path in sorted order so job indices are reproducible.
    for _id in sorted(os.listdir(INPUT_PATH)):
        id_path = os.path.join(INPUT_PATH, _id)
        for _seq_type in sorted(os.listdir(id_path)):
            type_path = os.path.join(id_path, _seq_type)
            for _view in sorted(os.listdir(type_path)):
                seq_info = [_id, _seq_type, _view]
                out_dir = os.path.join(OUTPUT_PATH, *seq_info)
                # exist_ok lets the script be re-run over an existing
                # output tree instead of failing on the first directory.
                os.makedirs(out_dir, exist_ok=True)
                results.append(
                    pool.apply_async(
                        cut_pickle,
                        args=(seq_info, pid)))
                sleep(0.02)
                pid += 1

    pool.close()
    # Poll the outstanding jobs until none is left. A timeout only means
    # "still running"; any other exception came from the worker and is
    # reported with its job index before being re-raised.
    pending = list(enumerate(results))
    while pending:
        still_running = []
        for job_id, res in pending:
            try:
                res.get(timeout=0.1)
            except MP_TimeoutError:
                still_running.append((job_id, res))
            except Exception as e:
                print('\n\n\nERROR OCCUR: PID ##%d##, ERRORTYPE: %s\n\n\n'
                      % (job_id, type(e)))
                raise
        pending = still_running
    pool.join()
|