#1 test_model

Closed
123455 wants to merge 6 commits from test_model into master
  1. +2
    -0
      .gitignore
  2. +1
    -1
      default_config.yaml
  3. +6
    -3
      src/util.py
  4. +14
    -0
      train.py

+ 2
- 0
.gitignore View File

@@ -0,0 +1,2 @@
*.ckpt
*.idea

+ 1
- 1
default_config.yaml View File

@@ -11,7 +11,7 @@ max_epoch: 285
total_epoch: 300
data_dir: "/home/work/user-job-dir/inputs/data/" # dataset base path
# checkpoint used for the final epochs that run without data augmentation
yolox_no_aug_ckpt: ""
yolox_no_aug_ckpt: "https://open-data.obs.cn-south-222.ai.pcl.cn:443/attachment/d/5/d51e19dd-9950-42a5-9f1e-baf13b97408b/285.zip?response-content-disposition=attachment%3B+filename%3D%22285.zip%22&AWSAccessKeyId=UJN8OQXLVBV0J9IHDGN9&Expires=1651131880&Signature=DX3Kii600T5RmxwGl2XobvqUTo4%3D"
need_profiler: 0
pretrained: ''
resume_yolox: ''


+ 6
- 3
src/util.py View File

@@ -419,8 +419,9 @@ class EvalCallBack(Callback):
self.train_network = train_network
self.detection = detection
self.logger = config.logger
self.start_epoch = config.start_epoch
self.start_epoch = (config.steps_per_epoch // config.log_interval) * config.start_epoch
self.interval = config.interval * config.steps_per_epoch // config.log_interval
# self.save_path = save_path
self.save_path = os.path.join(config.outputs_dir, 'ckpt_' + str(config.rank) + '/')
print("=============================self.save_path", self.save_path)

@@ -454,11 +455,13 @@ class EvalCallBack(Callback):
if results >= self.best_result:
self.best_result = results
self.best_epoch = cur_epoch
if not os.path.exists(self.save_path):
os.mkdir(self.save_path)
file_name = os.path.join(self.save_path, 'best.ckpt')
print("=====================file_name:", file_name)
if os.path.exists(file_name):
self.remove_ckpoint_file(file_name)
save_checkpoint(cb_param.train_network, file_name)
self.remove_ckpoint_file(file_name) # fixme debug
save_checkpoint(self.test_network, file_name)
self.logger.info("Best result %s at %s epoch" % (self.best_result, self.best_epoch))
self.logger.info(eval_print_str)
self.logger.info('Ending inference...')


+ 14
- 0
train.py View File

@@ -44,6 +44,7 @@ set_seed(42)
def set_default():
""" set default """
print('***********************data_url**************************', config.data_url, flush=True)

if config.enable_modelarts:
config.data_root = os.path.join(config.data_url, 'coco2017/train2017')
print('in default:',
@@ -279,6 +280,18 @@ def run_train():
if config.resume_yolox:
load_resume_params(config, network_ema)
if not config.data_aug:
# prepare the checkpoint for the final epochs that run without data augmentation
if config.enable_modelarts:
url = config.yolox_no_aug_ckpt
from urllib import request
import zipfile
print("downloading the ckpt file...")
request.urlretrieve(url, './285.zip')
zf = zipfile.ZipFile('./285.zip')
zf.extractall(path='./')
zf.close()
print('ckpt download done')
config.yolox_no_aug_ckpt = './285.ckpt'
if os.path.isfile(config.yolox_no_aug_ckpt): # Loading the resume checkpoint for the last no data aug epochs
param_dict = load_checkpoint(config.yolox_no_aug_ckpt)
if "learning_rate" in param_dict:
@@ -305,6 +318,7 @@ def run_train():
is_modelart=config.enable_modelarts,
per_print_times=config.log_interval, train_url=args_opt.train_url))
if config.run_eval:
config.logger.info('save_ckpt_path:', save_ckpt_path)
cb.append(
EvalCallBack(ds_test, test_network, network_ema, DetectionEngine(config), config, save_path=save_ckpt_path))
if config.need_profiler:


Loading…
Cancel
Save