|
- from collections import OrderedDict
-
- import torch.nn as nn
- from torch.nn.parallel import DataParallel
-
- from mmcv.parallel.registry import MODULE_WRAPPERS
- from mmcv.runner.checkpoint import get_state_dict
-
-
@MODULE_WRAPPERS.register_module()
class DDPWrapper:
    """Minimal stand-in for a distributed model wrapper, used by the tests.

    It is registered in ``MODULE_WRAPPERS`` and exposes the wrapped model only
    through the ``module`` attribute. The state-dict test expects
    ``get_state_dict`` to look through this wrapper and report the inner
    model's keys without any extra prefix.

    Args:
        module: The model being wrapped.
    """

    def __init__(self, module):
        # The only state held: the wrapped model.
        self.module = module
-
-
class Block(nn.Module):
    """Test fixture holding a 1x1 ``Conv2d`` and a ``BatchNorm2d`` on 3 channels.

    No ``forward`` is defined. The fixture exists only to provide parameters
    and buffers (including the BatchNorm running statistics) for state-dict
    tests.
    """

    def __init__(self):
        super(Block, self).__init__()
        # Submodules are registered in this order: ``conv`` first, then ``norm``.
        self.conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1)
        self.norm = nn.BatchNorm2d(num_features=3)
-
-
class Model(nn.Module):
    """Two-level test fixture: a nested ``Block`` plus a top-level 1x1 ``Conv2d``.

    Nesting means the state dict contains both prefixed (``block.*``) and
    top-level (``conv.*``) keys. No ``forward`` is defined.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.block = Block()
        self.conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1)
-
-
def assert_tensor_equal(tensor_a, tensor_b):
    """Assert that two tensors have the same shape and identical elements.

    ``Tensor.eq`` broadcasts its operands. On its own, ``a.eq(b).all()`` would
    therefore accept, for example, a shape-(3,) tensor of ones compared against
    a shape-(1,) tensor of ones. Checking the shapes first makes any size
    mismatch fail the assertion.

    Args:
        tensor_a: First tensor (e.g. an entry of a state dict).
        tensor_b: Second tensor (e.g. the parameter or buffer it should equal).
    """
    assert tensor_a.shape == tensor_b.shape, (
        f'shape mismatch: {tuple(tensor_a.shape)} vs {tuple(tensor_b.shape)}')
    assert tensor_a.eq(tensor_b).all()
-
-
def _assert_state_dict_matches(state_dict, expected_keys, block, conv):
    """Check a state dict returned by ``get_state_dict`` for a ``Model``.

    Args:
        state_dict: The value returned by ``get_state_dict``.
        expected_keys (set[str]): Exact set of keys ``state_dict`` must hold.
        block: The unwrapped ``Block`` whose tensors are expected under the
            ``block.`` prefix.
        conv: The unwrapped top-level ``Conv2d`` whose tensors are expected
            under the ``conv.`` prefix.
    """
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == expected_keys
    expected_tensors = {
        'block.conv.weight': block.conv.weight,
        'block.conv.bias': block.conv.bias,
        'block.norm.weight': block.norm.weight,
        'block.norm.bias': block.norm.bias,
        'block.norm.running_mean': block.norm.running_mean,
        'block.norm.running_var': block.norm.running_var,
        'block.norm.num_batches_tracked': block.norm.num_batches_tracked,
        'conv.weight': conv.weight,
        'conv.bias': conv.bias,
    }
    for key, tensor in expected_tensors.items():
        assert_tensor_equal(state_dict[key], tensor)


def test_get_state_dict():
    """``get_state_dict`` yields the same keys for plain and wrapped models.

    The test covers three cases:
    - a bare ``Model``;
    - the model inside ``DDPWrapper``;
    - the same wrapper after each direct child of the model has itself been
      wrapped in ``DataParallel``.

    In every case the key set must equal ``state_dict_keys``, i.e. contain no
    ``module.`` segments from the wrappers. Each entry must also equal the
    corresponding tensor of the unwrapped submodules.
    """
    state_dict_keys = {
        'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
        'block.norm.bias', 'block.norm.running_mean', 'block.norm.running_var',
        'block.norm.num_batches_tracked', 'conv.weight', 'conv.bias',
    }

    # Plain model.
    model = Model()
    _assert_state_dict_matches(
        get_state_dict(model), state_dict_keys, model.block, model.conv)

    # Model wrapped by the registered outer wrapper.
    wrapped_model = DDPWrapper(model)
    inner = wrapped_model.module
    _assert_state_dict_matches(
        get_state_dict(wrapped_model), state_dict_keys, inner.block,
        inner.conv)

    # Wrap every direct child in DataParallel as well. The children are
    # snapshotted first so the child registry is not mutated while being
    # iterated; setattr re-registers each submodule under its original name.
    for name, child in list(inner.named_children()):
        setattr(inner, name, DataParallel(child))
    _assert_state_dict_matches(
        get_state_dict(wrapped_model), state_dict_keys, inner.block.module,
        inner.conv.module)
|