#14 fix param load

Merged
msstudy merged 1 commits from ArthurZhao/bytetrack:master into master 1 year ago
  1. +17
    -17
      train.py

+ 17
- 17
train.py View File

@@ -279,23 +279,23 @@ def run_train():
config.val_ckpt = config.load_path
param_dict = load_checkpoint(config.pretrain_ckpt)
ema_param_dict = {}
# for param in param_dict:
# if "head_l.cls_preds" in param:
# continue
# if "head_m.cls_preds" in param:
# continue
# if "head_s.cls_preds" in param:
# continue
# if param.startswith("ema."):
# new_name = param.split("ema.")[1]
# data = param_dict[param]
# data.name = new_name
# ema_param_dict[new_name] = data
# if "moving_mean" in param:
# ema_param_dict[param] = param_dict[param]
# if "moving_variance" in param:
# ema_param_dict[param] = param_dict[param]
load_param_into_net(base_network, param_dict)
for param in param_dict:
if "head_l.cls_preds" in param:
continue
if "head_m.cls_preds" in param:
continue
if "head_s.cls_preds" in param:
continue
if param.startswith("ema."):
new_name = param.split("ema.")[1]
data = param_dict[param]
data.name = new_name
ema_param_dict[new_name] = data
if "moving_mean" in param:
ema_param_dict[param] = param_dict[param]
if "moving_variance" in param:
ema_param_dict[param] = param_dict[param]
load_param_into_net(base_network, ema_param_dict)
config.logger.info('load model %s success', config.val_ckpt)
if config.pretrained:
base_network = load_backbone(base_network, config.pretrained, config)


Loading…
Cancel
Save