|
- #!/usr/bin/env python3
-
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
-
"""Prune a classification model with an L1 filter pruner and evaluate it on the test set."""
-
- import argparse
- import sys
- from numpy.core.einsumfunc import _parse_possible_contraction
- import torch
-
- import xcom.core.builders as builders
- import xcom.core.config as config
- import xcom.core.distributed as dist
- import xcom.core.net as net
- import xcom.core.checkpoint as cp
-
- import xcom.core.trainer as trainer
- import xcom.models.scaler as scaler
- from xcom.core.config import cfg
- import xcom.core.logging as logging
- import xcom.core.meters as meters
- import xcom.datasets.loader as data_loader
-
- from xcom.pruning import L1FilterPruner
- from xcom.compression.utils.counter import count_flops_params
-
# Module-level logger obtained from xcom's logging wrapper (xcom.core.logging).
logger = logging.get_logger(__name__)
-
def parse_args(argv=None):
    """Parse command line options (config file and config overrides).

    There is no run-mode option; only ``--cfg`` and the trailing override
    list are accepted.

    Args:
        argv: Arguments to parse, excluding the program name. Defaults to
            ``sys.argv[1:]``, which matches the previous behavior.

    Returns:
        argparse.Namespace with ``cfg`` (str) and ``opts`` (list of str).

    Prints the help text and exits with status 1 when no arguments are given.
    """
    if argv is None:
        argv = sys.argv[1:]
    parser = argparse.ArgumentParser(description="Run a model.")
    parser.add_argument("--cfg", help="Config file location", required=True, type=str)
    # Everything after the known options is collected verbatim; main() passes
    # it to cfg.merge_from_list.
    parser.add_argument(
        "opts",
        help="See xcom/core/config.py for all options",
        default=None,
        nargs=argparse.REMAINDER,
    )
    if not argv:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args(argv)
-
def _log_complexity(model, dummy_input):
    """Log FLOPs/params of ``model`` on ``dummy_input`` and print ``net.complexity``."""
    flops, params, _ = count_flops_params(model, dummy_input, verbose=False)
    logger.info(f"FLOPs: {flops}, params: {params}")
    # NOTE(review): this measures builders.get_model(), not the pruned `model`
    # instance, so it may not reflect pruning -- confirm intent.
    print("complexity:", net.complexity(builders.get_model()))


def main():
    """Build, prune, and evaluate a classification model.

    Loads and freezes the config given on the command line, sets up the
    environment, builds the model on the current CUDA device, prunes all
    Conv2d layers with L1FilterPruner at 0.8 sparsity, loads the weights
    named by ``cfg.TEST.WEIGHTS``, reports model complexity, and runs one
    test epoch.

    Raises:
        ValueError: if ``cfg.TEST.IM_SIZE`` is unset/zero.
    """
    args = parse_args()
    config.load_cfg(args.cfg)
    cfg.merge_from_list(args.opts)
    config.assert_cfg()
    cfg.freeze()
    trainer.setup_env()

    model = builders.build_model()
    # Transfer the model to the current GPU device
    cur_device = torch.cuda.current_device()
    model = model.cuda(device=cur_device)

    # Target 80% sparsity on every Conv2d layer.
    config_list = [{"sparsity": 0.8, "op_types": ["Conv2d"]}]
    pruner = L1FilterPruner(model, config_list)
    pruner.compress()

    # NOTE(review): compress() runs before the checkpoint below is loaded; if
    # the pruner derives its masks from the current weights, they come from
    # the freshly built model rather than the trained one -- confirm order.
    test_weights = trainer.get_weights_file(cfg.TEST.WEIGHTS)
    cp.load_checkpoint(test_weights, model)

    im = cfg.TEST.IM_SIZE
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not im:
        raise ValueError(f"cfg.TEST.IM_SIZE must be set, got {im!r}")
    if cfg.VERBOSE:
        logger.info("Model:\n{}".format(model))

    # Log model complexity using a single 3-channel im x im dummy input.
    dummy_input = torch.randn([1, 3, im, im]).to(cur_device)
    _log_complexity(model, dummy_input)

    # Evaluate model. The model is not modified after this point, so a
    # second complexity report / test epoch would repeat identical work; if
    # fine-tuning is added, re-run both after it.
    test_loader = data_loader.construct_test_loader()
    test_meter = meters.TestMeter(len(test_loader))
    trainer.test_epoch(test_loader, model, test_meter, 0)
-
if __name__ == "__main__":
    # Run the prune-and-evaluate flow only when executed as a script.
    main()
|