|
- # Copyright (c) Open-MMLab. All rights reserved.
- import sys
- from collections.abc import Iterable
- from multiprocessing import Pool
- from shutil import get_terminal_size
-
- from .timer import Timer
-
-
class ProgressBar:
    """A progress bar which can print the progress.

    Args:
        task_num (int): Total number of tasks. When it is not positive the
            total is treated as unknown and only a running counter is
            printed instead of a bar.
        bar_width (int): Maximum width of the bar, in characters.
        start (bool): Whether to print the initial line and create the
            timer immediately.
        file: Writable text stream the bar is printed to (its ``write`` and
            ``flush`` methods are used).
    """

    def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
        self.task_num = task_num
        self.bar_width = bar_width
        self.completed = 0
        self.file = file
        # Created by start(). Kept as None until then so that update() can
        # start the bar lazily instead of failing with AttributeError.
        self.timer = None
        if start:
            self.start()

    @property
    def terminal_width(self):
        """int: Width of the terminal, in columns."""
        width, _ = get_terminal_size()
        return width

    def start(self):
        """Print the initial progress line and create the timer."""
        if self.task_num > 0:
            self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
                            'elapsed: 0s, ETA:')
        else:
            self.file.write('completed: 0, elapsed: 0s')
        self.file.flush()
        self.timer = Timer()

    def update(self, num_tasks=1):
        """Record finished tasks and redraw the progress line.

        Args:
            num_tasks (int): Number of tasks finished since the previous
                call. Must be positive.
        """
        assert num_tasks > 0
        if self.timer is None:
            # The bar was created with start=False and never started.
            self.start()
        self.completed += num_tasks
        elapsed = self.timer.since_start()
        fps = self.completed / elapsed if elapsed > 0 else float('inf')
        elapsed_secs = int(elapsed + 0.5)
        if self.task_num > 0:
            # Clamp to 1.0: if more tasks are reported than task_num, the ETA
            # must not turn negative and the bar must not grow past its
            # width.
            fraction = min(self.completed / float(self.task_num), 1.0)
            eta = int(elapsed * (1 - fraction) / fraction + 0.5)
            # '{{}}' leaves a literal '{}' placeholder that is filled with
            # the bar characters once the bar width is known.
            msg = (f'\r[{{}}] {self.completed}/{self.task_num}, '
                   f'{fps:.1f} task/s, elapsed: {elapsed_secs}s, '
                   f'ETA: {eta:5}s')

            # Shrink the bar so the whole line fits the terminal; the +2
            # gives back the two characters of the '{}' placeholder that
            # are counted in len(msg).
            bar_width = min(self.bar_width,
                            int(self.terminal_width - len(msg)) + 2,
                            int(self.terminal_width * 0.6))
            bar_width = max(2, bar_width)
            mark_width = int(bar_width * fraction)
            bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
            self.file.write(msg.format(bar_chars))
        else:
            # Leading '\r' rewrites the current line, like the known-total
            # branch; without it every update is appended to the same line.
            self.file.write(
                f'\rcompleted: {self.completed}, elapsed: {elapsed_secs}s,'
                f' {fps:.1f} tasks/s')
        self.file.flush()
-
-
def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
    """Track the progress of tasks execution with a progress bar.

    Tasks are done with a simple for-loop.

    Args:
        func (callable): The function to be applied to each task.
        tasks (Iterable or tuple[Iterable, int]): The tasks, or a
            ``(tasks, total num)`` tuple. An iterable given directly that
            has no length (e.g. a generator) is tracked with an unknown
            total.
        bar_width (int): Width of progress bar.
        file: Stream the progress bar is written to.
        **kwargs: Extra keyword arguments passed to every ``func`` call.

    Returns:
        list: The task results.

    Raises:
        TypeError: If ``tasks`` is neither a tuple nor an iterable.
    """
    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        assert isinstance(tasks[0], Iterable)
        assert isinstance(tasks[1], int)
        task_num = tasks[1]
        tasks = tasks[0]
    elif isinstance(tasks, Iterable):
        try:
            task_num = len(tasks)
        except TypeError:
            # Iterables such as generators have no len(); 0 makes
            # ProgressBar print a running counter instead of a bar.
            task_num = 0
    else:
        raise TypeError(
            '"tasks" must be an iterable object or a (iterator, int) tuple')
    prog_bar = ProgressBar(task_num, bar_width, file=file)
    results = []
    for task in tasks:
        results.append(func(task, **kwargs))
        prog_bar.update()
    # Finish the line that the bar has been rewriting in place.
    prog_bar.file.write('\n')
    return results
-
-
def init_pool(process_num, initializer=None, initargs=None):
    """Create a :class:`multiprocessing.Pool` with optional initializer.

    Args:
        process_num (int): Number of worker processes.
        initializer (None or callable): Passed to the pool when given.
        initargs (None or tuple): Arguments for ``initializer``; only used
            when ``initializer`` is given.

    Returns:
        multiprocessing.Pool: The created pool.

    Raises:
        TypeError: If ``initializer`` is given and ``initargs`` is neither
            None nor a tuple.
    """
    # Guard clauses: handle the simpler call shapes first.
    if initializer is None:
        return Pool(process_num)
    if initargs is None:
        return Pool(process_num, initializer)
    if isinstance(initargs, tuple):
        return Pool(process_num, initializer, initargs)
    raise TypeError('"initargs" must be a tuple')
-
-
def track_parallel_progress(func,
                            tasks,
                            nproc,
                            initializer=None,
                            initargs=None,
                            bar_width=50,
                            chunksize=1,
                            skip_first=False,
                            keep_order=True,
                            file=sys.stdout):
    """Track the progress of parallel task execution with a progress bar.

    The built-in :mod:`multiprocessing` module is used for process pools and
    tasks are done with :func:`Pool.imap` or :func:`Pool.imap_unordered`.

    Args:
        func (callable): The function to be applied to each task.
        tasks (Iterable or tuple[Iterable, int]): The tasks, or a
            ``(tasks, total num)`` tuple. An iterable given directly that
            has no length (e.g. a generator) is tracked with an unknown
            total.
        nproc (int): Process (worker) number.
        initializer (None or callable): Refer to :class:`multiprocessing.Pool`
            for details.
        initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
            details.
        chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
        bar_width (int): Width of progress bar.
        skip_first (bool): Whether to skip the first sample for each worker
            when estimating fps, since the initialization step may take
            longer.
        keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
            :func:`Pool.imap_unordered` is used.
        file: Stream the progress bar is written to.

    Returns:
        list: The task results.

    Raises:
        TypeError: If ``tasks`` is neither a tuple nor an iterable.
    """
    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        assert isinstance(tasks[0], Iterable)
        assert isinstance(tasks[1], int)
        task_num = tasks[1]
        tasks = tasks[0]
    elif isinstance(tasks, Iterable):
        try:
            task_num = len(tasks)
        except TypeError:
            # Iterables such as generators have no len(); a non-positive
            # total makes ProgressBar print a running counter instead.
            task_num = 0
    else:
        raise TypeError(
            '"tasks" must be an iterable object or a (iterator, int) tuple')
    pool = init_pool(nproc, initializer, initargs)
    try:
        start = not skip_first
        # Number of leading results excluded from timing when skip_first
        # is set (one chunk per worker).
        warmup = nproc * chunksize
        task_num -= warmup * int(skip_first)
        prog_bar = ProgressBar(task_num, bar_width, start, file=file)
        results = []
        if keep_order:
            gen = pool.imap(func, tasks, chunksize)
        else:
            gen = pool.imap_unordered(func, tasks, chunksize)
        for result in gen:
            results.append(result)
            if skip_first:
                if len(results) < warmup:
                    continue
                elif len(results) == warmup:
                    # Warm-up finished: start the bar (and its timer) now.
                    prog_bar.start()
                    continue
            prog_bar.update()
        # Finish the line that the bar has been rewriting in place.
        prog_bar.file.write('\n')
    except BaseException:
        # A failing task or an interrupt must not leave worker processes
        # behind: stop them right away, then re-raise unchanged.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
    return results
-
-
def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
    """Track the progress of tasks iteration or enumeration with a progress
    bar.

    Tasks are yielded with a simple for-loop.

    Args:
        tasks (Iterable or tuple[Iterable, int]): The tasks, or a
            ``(tasks, total num)`` tuple. An iterable given directly that
            has no length (e.g. a generator) is tracked with an unknown
            total.
        bar_width (int): Width of progress bar.
        file: Stream the progress bar is written to.

    Yields:
        Each task from ``tasks``, one at a time and unchanged. The bar is
        updated when the consumer asks for the next task.

    Raises:
        TypeError: If ``tasks`` is neither a tuple nor an iterable.
    """
    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        assert isinstance(tasks[0], Iterable)
        assert isinstance(tasks[1], int)
        task_num = tasks[1]
        tasks = tasks[0]
    elif isinstance(tasks, Iterable):
        try:
            task_num = len(tasks)
        except TypeError:
            # Iterables such as generators have no len(); 0 makes
            # ProgressBar print a running counter instead of a bar.
            task_num = 0
    else:
        raise TypeError(
            '"tasks" must be an iterable object or a (iterator, int) tuple')
    prog_bar = ProgressBar(task_num, bar_width, file=file)
    for task in tasks:
        yield task
        prog_bar.update()
    # Finish the line that the bar has been rewriting in place.
    prog_bar.file.write('\n')
|