|
|
@@ -279,23 +279,23 @@ def run_train(): |
|
|
|
config.val_ckpt = config.load_path |
|
|
|
param_dict = load_checkpoint(config.pretrain_ckpt) |
|
|
|
ema_param_dict = {} |
|
|
|
# for param in param_dict: |
|
|
|
# if "head_l.cls_preds" in param: |
|
|
|
# continue |
|
|
|
# if "head_m.cls_preds" in param: |
|
|
|
# continue |
|
|
|
# if "head_s.cls_preds" in param: |
|
|
|
# continue |
|
|
|
# if param.startswith("ema."): |
|
|
|
# new_name = param.split("ema.")[1] |
|
|
|
# data = param_dict[param] |
|
|
|
# data.name = new_name |
|
|
|
# ema_param_dict[new_name] = data |
|
|
|
# if "moving_mean" in param: |
|
|
|
# ema_param_dict[param] = param_dict[param] |
|
|
|
# if "moving_variance" in param: |
|
|
|
# ema_param_dict[param] = param_dict[param] |
|
|
|
load_param_into_net(base_network, param_dict) |
|
|
|
for param in param_dict: |
|
|
|
if "head_l.cls_preds" in param: |
|
|
|
continue |
|
|
|
if "head_m.cls_preds" in param: |
|
|
|
continue |
|
|
|
if "head_s.cls_preds" in param: |
|
|
|
continue |
|
|
|
if param.startswith("ema."): |
|
|
|
new_name = param.split("ema.")[1] |
|
|
|
data = param_dict[param] |
|
|
|
data.name = new_name |
|
|
|
ema_param_dict[new_name] = data |
|
|
|
if "moving_mean" in param: |
|
|
|
ema_param_dict[param] = param_dict[param] |
|
|
|
if "moving_variance" in param: |
|
|
|
ema_param_dict[param] = param_dict[param] |
|
|
|
load_param_into_net(base_network, ema_param_dict) |
|
|
|
config.logger.info('load model %s success', config.val_ckpt) |
|
|
|
if config.pretrained: |
|
|
|
base_network = load_backbone(base_network, config.pretrained, config) |
|
|
|