- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
-
- import ast
- import collections
- import contextlib
- import logging
- import os
- import re
- import time
- import traceback
- from collections import OrderedDict
- from typing import Any, Dict, Optional, Union
- from random import randint
-
- import torch
- from fairseq.dataclass.configs import CheckpointConfig
- from fairseq.dataclass.utils import (
- convert_namespace_to_omegaconf,
- overwrite_args_by_name,
- )
- from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
- from fairseq.file_io import PathManager
- from fairseq.models import FairseqDecoder, FairseqEncoder
- from omegaconf import DictConfig, open_dict, OmegaConf
-
-
- logger = logging.getLogger(__name__)
-
-
- def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
- from fairseq import meters
-
- # only one worker should attempt to create the required dir
- if trainer.data_parallel_rank == 0:
- os.makedirs(cfg.save_dir, exist_ok=True)
-
- prev_best = getattr(save_checkpoint, "best", val_loss)
- if val_loss is not None:
- best_function = max if cfg.maximize_best_checkpoint_metric else min
- save_checkpoint.best = best_function(val_loss, prev_best)
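- # the running best metric value is kept as an attribute on this function
- # (save_checkpoint.best) so it persists across calls; load_checkpoint()
- # restores it from a checkpoint's extra_state when training resumes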
-
- if cfg.no_save:
- return
-
- trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state
-
- if not trainer.should_save_checkpoint_on_current_rank:
- if trainer.always_call_state_dict_during_save_checkpoint:
- trainer.state_dict()
- return
-
- write_timer = meters.StopwatchMeter()
- write_timer.start()
-
- epoch = epoch_itr.epoch
- end_of_epoch = epoch_itr.end_of_epoch()
- updates = trainer.get_num_updates()
-
- logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")
-
- def is_better(a, b):
- return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
-
- suffix = trainer.checkpoint_suffix
- checkpoint_conds = collections.OrderedDict()
- checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
- end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
- )
- checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
- not end_of_epoch
- and cfg.save_interval_updates > 0
- and updates % cfg.save_interval_updates == 0
- )
- checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
- not hasattr(save_checkpoint, "best")
- or is_better(val_loss, save_checkpoint.best)
- )
- if val_loss is not None and cfg.keep_best_checkpoints > 0:
- worst_best = getattr(save_checkpoint, "best", None)
- chkpts = checkpoint_paths(
- cfg.save_dir,
- pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
- cfg.best_checkpoint_metric
- ),
- )
- if len(chkpts) > 0:
- p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
- worst_best = float(p.rsplit("_")[-1].replace(".pt", ""))
- # append a random integer suffix to resolve filename ties
- rand_sfx = randint(0, cfg.keep_best_checkpoints)
- checkpoint_conds[
- "checkpoint.best_{}_{:.3f}{}.pt".format(cfg.best_checkpoint_metric,
- val_loss, rand_sfx)
- ] = worst_best is None or is_better(val_loss, worst_best)
- checkpoint_conds[
- "checkpoint_last{}.pt".format(suffix)
- ] = not cfg.no_last_checkpoints
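- # with an empty suffix the candidate names built above look like
- # "checkpoint3.pt" (end of epoch), "checkpoint_3_15000.pt" (mid-epoch update),
- # "checkpoint_best.pt", "checkpoint.best_<metric>_<value>.pt" and
- # "checkpoint_last.pt"; only entries whose condition is True are written below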
-
- extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
- if hasattr(save_checkpoint, "best"):
- extra_state.update({"best": save_checkpoint.best})
-
- checkpoints = [
- trainer.moe_filename(os.path.join(cfg.save_dir, fn)) for fn, cond in checkpoint_conds.items() if cond
- ]
-
- if len(checkpoints) > 0:
- trainer.save_checkpoint(checkpoints[0], extra_state)
- for cp in checkpoints[1:]:
- if cfg.write_checkpoints_asynchronously:
- # TODO[ioPath]: Need to implement a delayed asynchronous
- # file copying/moving feature.
- logger.warning(
- f"ioPath is not copying {checkpoints[0]} to {cp} "
- "since async write mode is on."
- )
- else:
- assert PathManager.copy(
- checkpoints[0], cp, overwrite=True
- ), f"Failed to copy {checkpoints[0]} to {cp}"
-
- write_timer.stop()
- logger.info(
- "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
- checkpoints[0], epoch, updates, val_loss, write_timer.sum
- )
- )
-
- if not end_of_epoch and cfg.keep_interval_updates > 0:
- # remove old checkpoints; checkpoints are sorted in descending order
- if cfg.keep_interval_updates_pattern == -1:
- checkpoints = checkpoint_paths(
- cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
- )
- else:
- checkpoints = checkpoint_paths(
- cfg.save_dir,
- pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
- keep_match=True,
- )
- checkpoints = [
- x[0]
- for x in checkpoints
- if x[1] % cfg.keep_interval_updates_pattern != 0
- ]
-
- for old_chk in checkpoints[cfg.keep_interval_updates :]:
- if os.path.lexists(old_chk):
- os.remove(old_chk)
- elif PathManager.exists(old_chk):
- PathManager.rm(old_chk)
-
- if cfg.keep_last_epochs > 0:
- # remove old epoch checkpoints; checkpoints are sorted in descending order
- checkpoints = checkpoint_paths(
- cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
- )
- for old_chk in checkpoints[cfg.keep_last_epochs :]:
- if os.path.lexists(old_chk):
- os.remove(old_chk)
-
- if cfg.keep_best_checkpoints > 0:
- # only keep the best N checkpoints according to validation metric
- checkpoints = checkpoint_paths(
- cfg.save_dir,
- pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
- cfg.best_checkpoint_metric, suffix
- ),
- )
- if not cfg.maximize_best_checkpoint_metric:
- checkpoints = checkpoints[::-1]
- for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
- if os.path.lexists(old_chk):
- os.remove(old_chk)
-
-
- def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
- """
- Load a checkpoint and restore the training iterator.
-
- *passthrough_args* will be passed through to
- ``trainer.get_train_iterator``.
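-
- Illustrative usage (a sketch, assuming ``cfg`` is a fully populated fairseq
- config object and ``trainer`` is an initialized trainer)::
-
- extra_state, epoch_itr = load_checkpoint(cfg.checkpoint, trainer)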
- """
-
- reset_optimizer = cfg.reset_optimizer
- reset_lr_scheduler = cfg.reset_lr_scheduler
- optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
- reset_meters = cfg.reset_meters
- reset_dataloader = cfg.reset_dataloader
-
- if cfg.finetune_from_model is not None and (
- reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
- ):
- raise ValueError(
- "--finetune-from-model can not be set together with either --reset-optimizer"
- " or reset_lr_scheduler or reset_meters or reset_dataloader"
- )
-
- suffix = trainer.checkpoint_suffix
- if (
- cfg.restore_file == "checkpoint_last.pt"
- ): # default value of restore_file is 'checkpoint_last.pt'
- checkpoint_path = os.path.join(
- cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
- )
- first_launch = not PathManager.exists(checkpoint_path)
- if cfg.finetune_from_model is not None and first_launch:
- # if there is no last checkpoint to restore, start finetuning from the pretrained model;
- # otherwise use the usual logic to load the checkpoint (e.g. restart from the last checkpoint)
- if PathManager.exists(cfg.finetune_from_model):
- checkpoint_path = cfg.finetune_from_model
- reset_optimizer = True
- reset_lr_scheduler = True
- reset_meters = True
- reset_dataloader = True
- logger.info(
- f"loading pretrained model from {checkpoint_path}: "
- "optimizer, lr scheduler, meters, dataloader will be reset"
- )
- else:
- raise ValueError(
- f"--funetune-from-model {cfg.finetune_from_model} does not exist"
- )
- elif suffix is not None:
- checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
- else:
- checkpoint_path = cfg.restore_file
-
- if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
- raise ValueError(
- "--finetune-from-model and --restore-file (non-default value) "
- "can not be specified together: " + str(cfg)
- )
-
- extra_state = trainer.load_checkpoint(
- trainer.moe_filename(checkpoint_path),
- reset_optimizer,
- reset_lr_scheduler,
- optimizer_overrides,
- reset_meters=reset_meters,
- )
-
- if (
- extra_state is not None
- and "best" in extra_state
- and not reset_optimizer
- and not reset_meters
- ):
- save_checkpoint.best = extra_state["best"]
-
- if extra_state is not None and not reset_dataloader:
- # restore iterator from checkpoint
- itr_state = extra_state["train_iterator"]
- epoch_itr = trainer.get_train_iterator(
- epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
- )
- epoch_itr.load_state_dict(itr_state)
- else:
- epoch_itr = trainer.get_train_iterator(
- epoch=1, load_dataset=True, **passthrough_args
- )
-
- trainer.lr_step(epoch_itr.epoch)
-
- return extra_state, epoch_itr
-
-
- def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
- """Loads a checkpoint to CPU (with upgrading for backward compatibility).
-
- If doing single-GPU training or if the checkpoint is only being loaded by at
- most one process on each node (current default behavior is for only rank 0
- to read the checkpoint from disk), load_on_all_ranks should be False to
- avoid errors from torch.distributed not having been initialized or
- torch.distributed.barrier() hanging.
-
- If all processes on each node may be loading the checkpoint
- simultaneously, load_on_all_ranks should be set to True to avoid I/O
- conflicts.
-
- There's currently no support for > 1 but < all processes loading the
- checkpoint on each node.
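-
- A minimal usage sketch (hypothetical path)::
-
- state = load_checkpoint_to_cpu("/path/to/checkpoint_last.pt")
- model_state = state["model"]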
- """
- local_path = PathManager.get_local_path(path)
- # The locally cached file returned by get_local_path() may be stale for
- # remote files that are periodically updated/overwritten (ex:
- # checkpoint_last.pt) - so we remove the local copy, sync across processes
- # (if needed), and then download a fresh copy.
- if local_path != path and PathManager.path_requires_pathmanager(path):
- try:
- os.remove(local_path)
- except FileNotFoundError:
- # With potentially multiple processes removing the same file, the
- # file being missing is benign (missing_ok isn't available until
- # Python 3.8).
- pass
- if load_on_all_ranks:
- torch.distributed.barrier()
- local_path = PathManager.get_local_path(path)
-
- with open(local_path, "rb") as f:
- state = torch.load(f, map_location=torch.device("cpu"))
-
- if "args" in state and state["args"] is not None and arg_overrides is not None:
- args = state["args"]
- for arg_name, arg_val in arg_overrides.items():
- setattr(args, arg_name, arg_val)
-
- if "cfg" in state and state["cfg"] is not None:
-
- # hack to be able to set Namespace in dict config. this should be removed when we update to newer
- # omegaconf version that supports object flags, or when we migrate all existing models
- from omegaconf import _utils
-
- old_primitive = _utils.is_primitive_type
- _utils.is_primitive_type = lambda _: True
-
- state["cfg"] = OmegaConf.create(state["cfg"])
-
- _utils.is_primitive_type = old_primitive
- OmegaConf.set_struct(state["cfg"], True)
-
- if arg_overrides is not None:
- overwrite_args_by_name(state["cfg"], arg_overrides)
-
- state = _upgrade_state_dict(state)
- return state
-
-
- def load_model_ensemble(
- filenames,
- arg_overrides: Optional[Dict[str, Any]] = None,
- task=None,
- strict=True,
- suffix="",
- num_shards=1,
- state=None,
- ):
- """Loads an ensemble of models.
-
- Args:
- filenames (List[str]): checkpoint files to load
- arg_overrides (Dict[str,Any], optional): override model args that
- were used during model training
- task (fairseq.tasks.FairseqTask, optional): task to use for loading
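-
- Illustrative usage (a sketch; the checkpoint filenames are hypothetical)::
-
- models, model_args = load_model_ensemble(["model1.pt", "model2.pt"])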
- """
- assert not (
- strict and num_shards > 1
- ), "Cannot load state dict with strict=True and checkpoint shards > 1"
- ensemble, args, _task = load_model_ensemble_and_task(
- filenames,
- arg_overrides,
- task,
- strict,
- suffix,
- num_shards,
- state,
- )
- return ensemble, args
-
-
- def get_maybe_sharded_checkpoint_filename(
- filename: str, suffix: str, shard_idx: int, num_shards: int
- ) -> str:
- orig_filename = filename
- filename = filename.replace(".pt", suffix + ".pt")
- fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
- model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
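- # e.g. (hypothetical values) filename="checkpoint_last.pt", suffix="-rank0" and
- # shard_idx=1 give fsdp_filename="checkpoint_last-rank0-shard1.pt" and
- # model_parallel_filename="checkpoint_last_part1.pt"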
- if PathManager.exists(fsdp_filename):
- return fsdp_filename
- elif num_shards > 1:
- return model_parallel_filename
- else:
- return filename
-
-
- def load_model_ensemble_and_task(
- filenames,
- arg_overrides: Optional[Dict[str, Any]] = None,
- task=None,
- strict=True,
- suffix="",
- num_shards=1,
- state=None,
- ):
- assert state is None or len(filenames) == 1
-
- from fairseq import tasks
-
- assert not (
- strict and num_shards > 1
- ), "Cannot load state dict with strict=True and checkpoint shards > 1"
- ensemble = []
- cfg = None
- for filename in filenames:
- orig_filename = filename
- model_shard_state = {"shard_weights": [], "shard_metadata": []}
- assert num_shards > 0
- st = time.time()
- for shard_idx in range(num_shards):
- filename = get_maybe_sharded_checkpoint_filename(
- orig_filename, suffix, shard_idx, num_shards
- )
-
- if not PathManager.exists(filename):
- raise IOError("Model file not found: {}".format(filename))
- if state is None:
- state = load_checkpoint_to_cpu(filename, arg_overrides)
- if "args" in state and state["args"] is not None:
- cfg = convert_namespace_to_omegaconf(state["args"])
- elif "cfg" in state and state["cfg"] is not None:
- cfg = state["cfg"]
- else:
- raise RuntimeError(
- f"Neither args nor cfg exist in state keys = {state.keys()}"
- )
-
- if task is None:
- task = tasks.setup_task(cfg.task)
-
- if "task_state" in state:
- task.load_state_dict(state["task_state"])
-
- if "fsdp_metadata" in state and num_shards > 1:
- model_shard_state["shard_weights"].append(state["model"])
- model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
- # check FSDP import before the code goes too far
- if not has_FSDP:
- raise ImportError(
- "Cannot find FullyShardedDataParallel. "
- "Please install fairscale with: pip install fairscale"
- )
- if shard_idx == num_shards - 1:
- consolidated_model_state = FSDP.consolidate_shard_weights(
- shard_weights=model_shard_state["shard_weights"],
- shard_metadata=model_shard_state["shard_metadata"],
- )
- model = task.build_model(cfg.model)
- model.load_state_dict(
- consolidated_model_state, strict=strict, model_cfg=cfg.model
- )
- else:
- # model parallel checkpoint or unsharded checkpoint
- model = task.build_model(cfg.model)
- # print(f"checkpoint_utils.py line 448 cfg.model={cfg.model}")
- # print(f"checkpoint_utils.py line 449 model:{model}")
- model.load_state_dict(
- state["model"], strict=strict, model_cfg=cfg.model
- )
-
- # reset state so it gets loaded for the next model in ensemble
- state = None
- if shard_idx % 10 == 0 and shard_idx > 0:
- elapsed = time.time() - st
- logger.info(f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard")
-
- # add the loaded model to the ensemble
- ensemble.append(model)
- return ensemble, cfg, task
-
-
- def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
- """Retrieves all checkpoints found in `path` directory.
-
- Checkpoints are identified by matching filename to the specified pattern. If
- the pattern contains groups, the result will be sorted by the first group in
- descending order.
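- If `keep_match` is True, (path, sort key) tuples are returned instead of
- bare paths.
-
- For example, with the default pattern a directory containing
- "checkpoint9.pt" and "checkpoint10.pt" yields the path of checkpoint10
- followed by the path of checkpoint9.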
- """
- pt_regexp = re.compile(pattern)
- files = PathManager.ls(path)
-
- entries = []
- for i, f in enumerate(files):
- m = pt_regexp.fullmatch(f)
- if m is not None:
- idx = float(m.group(1)) if len(m.groups()) > 0 else i
- entries.append((idx, m.group(0)))
- if keep_match:
- return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
- else:
- return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
-
-
- def torch_persistent_save(obj, filename, async_write: bool = False):
- if async_write:
- with PathManager.opena(filename, "wb") as f:
- _torch_persistent_save(obj, f)
- else:
- if PathManager.supports_rename(filename):
- # do atomic save
- with PathManager.open(filename + ".tmp", "wb") as f:
- _torch_persistent_save(obj, f)
- PathManager.rename(filename + ".tmp", filename)
- else:
- # fall back to non-atomic save
- with PathManager.open(filename, "wb") as f:
- _torch_persistent_save(obj, f)
-
-
- def _torch_persistent_save(obj, f):
- if isinstance(f, str):
- with PathManager.open(f, "wb") as h:
- # re-enter with the open file handle rather than the path string
- _torch_persistent_save(obj, h)
- return
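- # retry the save up to three times; only log the traceback if the final attempt also fails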
- for i in range(3):
- try:
- return torch.save(obj, f)
- except Exception:
- if i == 2:
- logger.error(traceback.format_exc())
-
-
- def _upgrade_state_dict(state):
- """Helper for upgrading old model checkpoints."""
-
- # add optimizer_history
- if "optimizer_history" not in state:
- state["optimizer_history"] = [
- {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
- ]
- state["last_optimizer_state"] = state["optimizer"]
- del state["optimizer"]
- del state["best_loss"]
- # move extra_state into sub-dictionary
- if "epoch" in state and "extra_state" not in state:
- state["extra_state"] = {
- "epoch": state["epoch"],
- "batch_offset": state["batch_offset"],
- "val_loss": state["val_loss"],
- }
- del state["epoch"]
- del state["batch_offset"]
- del state["val_loss"]
- # reduce optimizer history's memory usage (only keep the last state)
- if "optimizer" in state["optimizer_history"][-1]:
- state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
- for optim_hist in state["optimizer_history"]:
- del optim_hist["optimizer"]
- # record the optimizer class name
- if "optimizer_name" not in state["optimizer_history"][-1]:
- state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
- # move best_loss into lr_scheduler_state
- if "lr_scheduler_state" not in state["optimizer_history"][-1]:
- state["optimizer_history"][-1]["lr_scheduler_state"] = {
- "best": state["optimizer_history"][-1]["best_loss"]
- }
- del state["optimizer_history"][-1]["best_loss"]
- # keep track of number of updates
- if "num_updates" not in state["optimizer_history"][-1]:
- state["optimizer_history"][-1]["num_updates"] = 0
- # old model checkpoints may not have separate source/target positions
- if (
- "args" in state
- and hasattr(state["args"], "max_positions")
- and not hasattr(state["args"], "max_source_positions")
- ):
- state["args"].max_source_positions = state["args"].max_positions
- state["args"].max_target_positions = state["args"].max_positions
- # use stateful training data iterator
- if "train_iterator" not in state["extra_state"]:
- state["extra_state"]["train_iterator"] = {
- "epoch": state["extra_state"]["epoch"],
- "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
- }
-
- # backward compatibility, cfg updates
- if "args" in state and state["args"] is not None:
- # default to translation task
- if not hasattr(state["args"], "task"):
- state["args"].task = "translation"
- # --raw-text and --lazy-load are deprecated
- if getattr(state["args"], "raw_text", False):
- state["args"].dataset_impl = "raw"
- elif getattr(state["args"], "lazy_load", False):
- state["args"].dataset_impl = "lazy"
- # epochs start at 1
- if state["extra_state"]["train_iterator"] is not None:
- state["extra_state"]["train_iterator"]["epoch"] = max(
- state["extra_state"]["train_iterator"].get("epoch", 1), 1
- )
- # --remove-bpe ==> --post-process
- if hasattr(state["args"], "remove_bpe"):
- state["args"].post_process = state["args"].remove_bpe
- # --min-lr ==> --stop-min-lr
- if hasattr(state["args"], "min_lr"):
- state["args"].stop_min_lr = state["args"].min_lr
- del state["args"].min_lr
- # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
- if (
- hasattr(state["args"], "criterion")
- and state["args"].criterion in [
- "binary_cross_entropy",
- "kd_binary_cross_entropy",
- ]
- ):
- state["args"].criterion = "wav2vec"
- # remove log_keys if it's None (the criterion will supply a default value of [])
- if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
- delattr(state["args"], "log_keys")
- # speech_pretraining => audio pretraining
- if (
- hasattr(state["args"], "task")
- and state["args"].task == "speech_pretraining"
- ):
- state["args"].task = "audio_pretraining"
- # audio_cpc => wav2vec
- if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
- state["args"].arch = "wav2vec"
- # convert legacy float learning rate to List[float]
- if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
- state["args"].lr = [state["args"].lr]
- # convert task data arg to a string instead of List[string]
- if (
- hasattr(state["args"], "data")
- and isinstance(state["args"].data, list)
- and len(state["args"].data) > 0
- ):
- state["args"].data = state["args"].data[0]
- # remove keys in state["args"] related to teacher-student learning
- for key in [
- "static_teachers",
- "static_teacher_weights",
- "dynamic_teachers",
- "dynamic_teacher_weights",
- ]:
- if key in state["args"]:
- delattr(state["args"], key)
-
- state["cfg"] = convert_namespace_to_omegaconf(state["args"])
-
- if "cfg" in state and state["cfg"] is not None:
- cfg = state["cfg"]
- with open_dict(cfg):
- # any upgrades for Hydra-based configs
- if (
- "task" in cfg
- and "eval_wer_config" in cfg.task
- and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
- ):
- cfg.task.eval_wer_config.print_alignment = "hard"
- if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
- cfg.generation.print_alignment = "hard"
- if (
- "model" in cfg
- and "w2v_args" in cfg.model
- and cfg.model.w2v_args is not None
- and (
- hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
- )
- and isinstance(
- cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
- )
- ):
- cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"
-
- return state
-
-
- def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
- """Prune the given state_dict if desired for LayerDrop
- (https://arxiv.org/abs/1909.11556).
-
- Training with LayerDrop allows models to be robust to pruning at inference
- time. This function prunes state_dict to allow smaller models to be loaded
- from a larger model and re-maps the existing state_dict for this to occur.
-
- It's called by functions that load models from checkpoints and does not
- need to be called directly.
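-
- For example (illustrative), with ``encoder_layers_to_keep="0,2,4"`` the
- weights of encoder layers 0, 2 and 4 are remapped to layers 0, 1 and 2 of
- the pruned model, and the remaining encoder layers are dropped from the
- returned state dict.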
- """
- arch = None
- if model_cfg is not None:
- arch = (
- model_cfg._name
- if isinstance(model_cfg, DictConfig)
- else getattr(model_cfg, "arch", None)
- )
-
- if not model_cfg or arch is None or arch == "ptt_transformer":
- # args should not be None, but don't crash if it is.
- return state_dict
-
- encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
- decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
-
- if not encoder_layers_to_keep and not decoder_layers_to_keep:
- return state_dict
-
- # apply pruning
- logger.info(
- "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
- )
-
- def create_pruning_pass(layers_to_keep, layer_name):
- keep_layers = sorted(
- int(layer_string) for layer_string in layers_to_keep.split(",")
- )
- mapping_dict = {}
- for i in range(len(keep_layers)):
- mapping_dict[str(keep_layers[i])] = str(i)
-
- regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
- return {"substitution_regex": regex, "mapping_dict": mapping_dict}
-
- pruning_passes = []
- if encoder_layers_to_keep:
- pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
- if decoder_layers_to_keep:
- pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
-
- new_state_dict = {}
- for layer_name in state_dict.keys():
- match = re.search(r"\.layers\.(\d+)\.", layer_name)
- # if layer has no number in it, it is a supporting layer, such as an
- # embedding
- if not match:
- new_state_dict[layer_name] = state_dict[layer_name]
- continue
-
- # otherwise, the layer is either remapped to a new index or dropped
- original_layer_number = match.group(1)
- # figure out which mapping dict to replace from
- for pruning_pass in pruning_passes:
- if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
- "substitution_regex"
- ].search(layer_name):
- new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
- substitution_match = pruning_pass["substitution_regex"].search(
- layer_name
- )
- new_state_key = (
- layer_name[: substitution_match.start(1)]
- + new_layer_number
- + layer_name[substitution_match.end(1) :]
- )
- new_state_dict[new_state_key] = state_dict[layer_name]
-
- # Since layers are now pruned, *_layers_to_keep are no longer needed.
- # This is more of "It would make it work fix" rather than a proper fix.
- if isinstance(model_cfg, DictConfig):
- context = open_dict(model_cfg)
- else:
- context = contextlib.ExitStack()
- with context:
- if hasattr(model_cfg, "encoder_layers_to_keep"):
- model_cfg.encoder_layers_to_keep = None
- if hasattr(model_cfg, "decoder_layers_to_keep"):
- model_cfg.decoder_layers_to_keep = None
-
- return new_state_dict
-
-
- def load_pretrained_component_from_model(
- component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
- ):
- """
- Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
- provided `component` object. If state_dict fails to load, there may be a
- mismatch in the architecture of the corresponding `component` found in the
- `checkpoint` file.
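-
- A minimal usage sketch (hypothetical checkpoint path; assumes the model has
- an `encoder` attribute)::
-
- model.encoder = load_pretrained_component_from_model(model.encoder, "pretrained.pt")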
- """
- if not PathManager.exists(checkpoint):
- raise IOError("Model file not found: {}".format(checkpoint))
- state = load_checkpoint_to_cpu(checkpoint)
- if isinstance(component, FairseqEncoder):
- component_type = "encoder"
- elif isinstance(component, FairseqDecoder):
- component_type = "decoder"
- else:
- raise ValueError(
- "component to load must be either a FairseqEncoder or "
- "FairseqDecoder. Loading other component types are not supported."
- )
- component_state_dict = OrderedDict()
- for key in state["model"].keys():
- if key.startswith(component_type):
- # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
- component_subkey = key[len(component_type) + 1 :]
- component_state_dict[component_subkey] = state["model"][key]
- component.load_state_dict(component_state_dict, strict=True)
- return component
-
-
- def verify_checkpoint_directory(save_dir: str) -> None:
- if not os.path.exists(save_dir):
- os.makedirs(save_dir, exist_ok=True)
- temp_file_path = os.path.join(save_dir, "dummy")
- try:
- with open(temp_file_path, "w"):
- pass
- except OSError as e:
- logger.warning(
- "Unable to access checkpoint save directory: {}".format(save_dir)
- )
- raise e
- else:
- os.remove(temp_file_path)