|
- import torch
-
- from mmcv.cnn.bricks import GeneralizedAttention
-
-
def _check_output_shape(block, imgs):
    """Run ``block`` on ``imgs`` and assert the output shape equals the input shape."""
    out = block(imgs)
    assert out.shape == imgs.shape


def test_context_block():
    """Smoke-test ``GeneralizedAttention`` construction and forward passes.

    Despite its name (kept so the existing pytest node ID stays stable), this
    test exercises ``GeneralizedAttention``. It covers:

    * each single-term ``attention_type`` mask ('1000', '0100', '0010', '0001');
    * a non-negative ``spatial_range``;
    * ``q_stride > 1`` and ``kv_stride > 1``.

    For every configuration it checks that the sub-modules/parameters used by
    that configuration exist. It also checks that the forward pass returns a
    tensor with the same shape as its input.
    """
    # test attention_type='1000'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='1000')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.key_conv.in_channels == 16
    # The original repeated the key_conv assertion here; check the value
    # projection instead so all three q/k/v convs are covered.
    assert gen_attention_block.value_conv.in_channels == 16
    _check_output_shape(gen_attention_block, imgs)

    # test attention_type='0100'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0100')
    assert gen_attention_block.query_conv.in_channels == 16
    # 8 == 16 // 2: presumably the default position-embedding dim equals
    # in_channels and is split evenly between x and y -- confirm upstream.
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    _check_output_shape(gen_attention_block, imgs)

    # test attention_type='0010'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0010')
    assert gen_attention_block.key_conv.in_channels == 16
    assert hasattr(gen_attention_block, 'appr_bias')
    _check_output_shape(gen_attention_block, imgs)

    # test attention_type='0001'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0001')
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    assert hasattr(gen_attention_block, 'geom_bias')
    _check_output_shape(gen_attention_block, imgs)

    # test spatial_range >= 0
    # Uses 256 channels rather than 16; presumably the local constraint map
    # is only supported for specific channel counts -- TODO confirm.
    imgs = torch.randn(2, 256, 20, 20)
    gen_attention_block = GeneralizedAttention(256, spatial_range=10)
    assert hasattr(gen_attention_block, 'local_constraint_map')
    _check_output_shape(gen_attention_block, imgs)

    # test q_stride > 1
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, q_stride=2)
    assert gen_attention_block.q_downsample is not None
    _check_output_shape(gen_attention_block, imgs)

    # test kv_stride > 1
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, kv_stride=2)
    assert gen_attention_block.kv_downsample is not None
    _check_output_shape(gen_attention_block, imgs)
|