|
- import json
- import logging
- import os
- import signal
- import sys
- import time
- from concurrent.futures import ThreadPoolExecutor
- from datetime import timedelta
- from typing import Any, Dict, Optional
-
- from transformers import TrainerCallback
- from transformers.trainer_utils import has_length
-
-
-
class LogCallback(TrainerCallback):
    r"""
    Tracks training/evaluation progress and asynchronously appends structured
    progress records to ``trainer_log.jsonl`` inside the output directory.

    A single-worker thread pool performs the file writes so that disk I/O
    never blocks the training loop.
    """

    def __init__(self) -> None:
        r"""
        Initializes a callback for logging training and evaluation status.
        """
        # Progress tracking state.
        self.start_time = 0.0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        # One worker is enough: writes are small and must stay ordered.
        self.thread_pool: Optional["ThreadPoolExecutor"] = None
        # Status flags.
        self.aborted = False
        self.do_train = False
        # NOTE(review): `_set_abort` has a signal-handler signature but is
        # never registered here; without e.g.
        # ``signal.signal(signal.SIGABRT, callback._set_abort)`` performed by
        # the owner, `self.aborted` can never become True — confirm intent.

    def _set_abort(self, signum, frame) -> None:
        r"""
        Signal handler requesting a graceful stop of training/prediction.
        """
        self.aborted = True

    def _reset(self, max_steps: int = 0) -> None:
        r"""
        Resets progress counters at the start of a training or eval run.
        """
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        r"""
        Updates elapsed/remaining time estimates for the given step count.
        """
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        # Guard the very first call, where no step has completed yet.
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
        r"""
        Appends one JSON record per line to ``trainer_log.jsonl``.

        Runs on the thread-pool worker; `output_dir` already exists because
        `_create_thread_pool` created it before any submission.
        """
        with open(os.path.join(output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs, ensure_ascii=False) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        r"""
        Ensures `output_dir` exists and spins up the single-worker pool.
        """
        os.makedirs(output_dir, exist_ok=True)
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        r"""
        Flushes pending log writes and releases the pool, if any.
        """
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of trainer initialization. No-op.
        """
        pass

    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
        """
        # Only the process responsible for saving tracks/writes progress.
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        self._close_thread_pool()

    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a substep during gradient accumulation.
        """
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a training step.
        """
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after an evaluation phase.
        """
        # Pool created by `on_prediction_step` for eval-only runs; training
        # runs keep theirs open until `on_train_end`.
        if not self.do_train:
            self._close_thread_pool()

    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a successful prediction.
        """
        if not self.do_train:
            self._close_thread_pool()

    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after logging the last logs; collects metrics from the
        latest history entry and enqueues them for the JSONL writer.
        """
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        # Guard: `log_history` can be empty on the very first log event.
        last_log = state.log_history[-1] if state.log_history else {}
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=last_log.get("loss", None),
            eval_loss=last_log.get("eval_loss", None),
            predict_loss=last_log.get("predict_loss", None),
            reward=last_log.get("reward", None),
            accuracy=last_log.get("rewards/accuracies", None),
            learning_rate=last_log.get("learning_rate", None),
            epoch=last_log.get("epoch", None),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        # `num_input_tokens_seen` is only populated when token tracking is
        # enabled in the trainer; avoid AttributeError and zero-division.
        num_tokens = getattr(state, "num_input_tokens_seen", None)
        if num_tokens:
            logs["throughput"] = "{:.2f}".format(num_tokens / max(time.time() - self.start_time, 1e-6))
            logs["total_tokens"] = num_tokens
        logs = {k: v for k, v in logs.items() if v is not None}
        if all(key in logs for key in ["loss", "learning_rate", "epoch"]):
            print(
                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
                    logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
                )
            )

        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        r"""
        Event called after a prediction step; drives progress reporting for
        eval/predict-only runs (training runs report via `on_log`).
        """
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            # First prediction step of an eval-only run: set up progress
            # tracking and the writer pool lazily.
            if self.max_steps == 0:
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            # Throttle: write a progress record every 5 steps only.
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)
|