|
- # -*- coding: utf-8 -*-
- """
- @Time : 2022-05-10 10:30
- @Author : Zhuxb
-
- """
- import json
- import os
- import contextlib
- import copy
- import functools
- import logging
- import os
- import sys
- import time
- import threading
- from typing import List
- import colorlog
- from colorama import Fore
- import paddle
-
class DefaultLogger(object):
    """Colored console + file logger with extra TRAIN/EVAL levels.

    Handlers are attached to the root logger (``logging.getLogger()``), so any
    record that reaches the root logger is written to both outputs.

    For every level name in the config, two callables are installed on the
    instance (e.g. ``self.INFO`` / ``self.info``, ``self.TRAIN`` /
    ``self.train``); each forwards its message to :meth:`__call__` with the
    matching numeric level.

    Args:
        args: configuration object. ``args.output_path`` is the root directory
            for the ``logs`` sub-folder; ``args.envir`` is read on every log
            call (when it equals ``"cloud"`` the message is also printed).
    """

    def __init__(self, args):
        log_config = {
            'DEBUG': {
                'level': 10,
                'color': 'purple'
            },
            'INFO': {
                'level': 20,
                'color': 'green'
            },
            'TRAIN': {
                'level': 21,
                'color': 'cyan'
            },
            'EVAL': {
                'level': 22,
                'color': 'blue'
            },
            'WARNING': {
                'level': 30,
                'color': 'yellow'
            },
            'ERROR': {
                'level': 40,
                'color': 'red'
            },
            'CRITICAL': {
                'level': 50,
                'color': 'bold_red'
            }
        }

        self.args = args
        # Log files go into <output_path>/logs by default.
        log_dir = f'{args.output_path}/logs'
        os.makedirs(log_dir, exist_ok=True)
        # NOTE(review): this is the root logger; constructing a second
        # DefaultLogger adds another pair of handlers and duplicates output.
        self.logger = logging.getLogger()
        for key, conf in log_config.items():
            logging.addLevelName(conf['level'], key)
            level_fn = functools.partial(self.__call__, conf['level'])
            # Expose both spellings, e.g. self.TRAIN(...) and self.train(...).
            self.__dict__[key] = level_fn
            self.__dict__[key.lower()] = level_fn

        # File handler: one timestamped log file per logger instance.
        time_format = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime(time.time()))
        # '%(asctime)-15s' pads the timestamp (same as the console format).
        self.file_format = logging.Formatter('[%(asctime)-15s] - [%(levelname)8s] - %(message)s')
        self.file_handler = logging.FileHandler(
            f'{log_dir}/output_{time_format}.log', encoding='utf-8')
        self.file_handler.setFormatter(self.file_format)
        self.logger.addHandler(self.file_handler)

        # Console handler with per-level colors taken from log_config.
        self.console_format = colorlog.ColoredFormatter(
            '%(log_color)s[%(asctime)-15s] [%(levelname)8s]%(reset)s - %(message)s',
            log_colors={
                key: conf['color']
                for key, conf in log_config.items()
            })
        self.console_handler = logging.StreamHandler()
        self.console_handler.setFormatter(self.console_format)
        self.logger.addHandler(self.console_handler)

        self.logger.setLevel(logging.DEBUG)
        # NOTE(review): propagate has no effect on the root logger (no parent).
        self.logger.propagate = False
        self._is_enable = True

    def disable(self):
        """Suppress all messages routed through the level callables."""
        self._is_enable = False

    def enable(self):
        """Re-enable messages routed through the level callables."""
        self._is_enable = True

    @property
    def is_enable(self) -> bool:
        """Whether the level callables currently emit messages."""
        return self._is_enable

    def __call__(self, log_level: int, msg: str):
        """Log ``msg`` at numeric ``log_level`` unless the logger is disabled.

        When paddle reports more than one process, the message is prefixed
        with the current rank. When ``args.envir == "cloud"`` the raw message
        is additionally printed to stdout.
        """
        if not self.is_enable:
            return
        trainer_num = paddle.distributed.get_world_size()
        trainer_id = paddle.distributed.get_rank()

        if trainer_num > 1:
            self.logger.log(log_level, f"Current rank [{trainer_id}] - {msg}")
        else:
            self.logger.log(log_level, msg)
        # getattr: tolerate config objects that do not define `envir`.
        if getattr(self.args, 'envir', None) == "cloud":
            print(msg)

    @contextlib.contextmanager
    def use_terminator(self, terminator: str):
        """Temporarily replace the console handler's line terminator.

        The previous terminator is restored even if the managed block raises.

        Args:
            terminator: string appended after each console record, e.g. '\\r'.
        """
        handler = self.console_handler
        old_terminator = handler.terminator
        handler.terminator = terminator
        try:
            yield
        finally:
            handler.terminator = old_terminator

    @contextlib.contextmanager
    def processing(self, msg: str, interval: float = 0.1):
        '''
        Continuously print a progress indicator with a rotating spinner.

        A background thread logs ``"<msg>: <flag>"`` at INFO level every
        ``interval`` seconds (console records end with '\\r') until the
        managed block exits. The thread is stopped and joined even if the
        block raises.

        Args:
            msg(str): Message to be printed.
            interval(float): Rotation interval in seconds. Default to 0.1.
        '''
        stop = threading.Event()

        def _printer():
            flags = ['\\', '|', '/', '-']
            index = 0
            while not stop.is_set():
                flag = flags[index % len(flags)]
                with self.use_terminator('\r'):
                    self.info('{}: {}'.format(msg, flag))
                # Event.wait returns early once stop is set, so exit is prompt.
                stop.wait(interval)
                index += 1

        t = threading.Thread(target=_printer, daemon=True)
        t.start()
        try:
            yield
        finally:
            stop.set()
            t.join()

    def write_args(self, args):
        """Log ``args`` (rendered via ``str.format``) between dashed separators.

        Writes through ``self.logger.info`` directly, so it is not affected by
        :meth:`disable` and gets no rank prefix.
        """
        args_format = "--------------args------------------- \n{}\n-----------------------------------------"
        self.logger.info(args_format.format(args))
|