|
- import torch
- from torch import nn
-
-
class PA_layer(nn.Module):
    """Map a clip of frames to per-pixel distances between neighbouring frames.

    Every frame goes through one shared shallow convolution. For each pair of
    consecutive frames in a clip, the L2 distance between their feature
    vectors is taken at every pixel, giving ``n_length - 1`` single-channel
    maps per clip.

    Args:
        n_length: Number of consecutive frames that make up one clip.
        in_dim: Channels per input frame (default 3).
        out_dim: Channels produced by the shallow convolution (default 8).
    """

    def __init__(self, n_length, in_dim=3, out_dim=8):
        super(PA_layer, self).__init__()
        # Kernel 7, stride 1, padding 3: the spatial size is preserved.
        self.shallow_conv = nn.Conv2d(in_dim, out_dim, 7, 1, 3)
        self.n_length = n_length
        nn.init.normal_(self.shallow_conv.weight.data, 0, 0.001)
        nn.init.constant_(self.shallow_conv.bias.data, 0)

    def forward(self, x):
        """Compute the frame-to-frame distance maps.

        Args:
            x: Tensor whose last two dimensions are (H, W) and whose leading
                dimensions fold into ``clips * n_length * in_dim`` channels,
                e.g. ``(B, n_length * in_dim, H, W)``.

        Returns:
            Tensor of shape ``(clips, n_length - 1, H, W)``. With
            ``n_length == 1`` there are no pairs and the channel dimension
            is empty.
        """
        height, width = x.size(-2), x.size(-1)
        frames = x.reshape((-1, self.shallow_conv.in_channels, height, width))
        feats = self.shallow_conv(frames)
        # (clips, n_length, C, H*W)
        feats = feats.reshape(-1, self.n_length, feats.size(1), height * width)
        # Distance between frame i and frame i + 1 for all i at once. The norm
        # is taken explicitly over the channel axis: nn.PairwiseDistance
        # reduces over the last dimension in recent PyTorch releases, which for
        # these 3-D slices would be the pixel axis. The 1e-6 offset mirrors
        # PairwiseDistance's default eps.
        diff = feats[:, :-1] - feats[:, 1:] + 1e-6
        dist = torch.norm(diff, p=2, dim=2)
        # Use the real input resolution instead of assuming 224x224.
        return dist.reshape(dist.size(0), self.n_length - 1, height, width)
-
-
- # TODO: VIP_level
class VIP_layer(nn.Module):
    """Classify a clip by pooling per-segment features at three time scales.

    The ``n_segment`` feature vectors of a clip are max-pooled over time with
    dilations 1, 2 and 4, leaving 1, 2 and 4 temporal positions respectively.
    Each pooled result goes through dropout and its own linear classifier, and
    the three score vectors are summed with weights 1, 2 and 4.

    Args:
        n_segment: Number of segments per clip; must be a positive multiple
            of 4 so that every scale pools an equal number of segments.
        feature_dim: Size of each per-segment feature vector.
        num_class: Number of output classes.
        dropout_ratio: Dropout probability applied before each classifier.
        VIP_level: Accepted for interface compatibility; currently unused
            (the three scales above are always built).

    Raises:
        ValueError: If ``n_segment`` is not a positive multiple of 4.
    """

    def __init__(self, n_segment, feature_dim, num_class, dropout_ratio, VIP_level=3):
        super(VIP_layer, self).__init__()
        if n_segment <= 0 or n_segment % 4 != 0:
            raise ValueError(
                "n_segment must be a positive multiple of 4, got %r" % (n_segment,)
            )
        # Scale s uses kernel n_segment // s with dilation s (stride 1, no
        # padding), so exactly s temporal positions remain. For n_segment == 8
        # these are the kernels 8, 4 and 2.
        self.VIP_1 = nn.MaxPool3d((n_segment, 1, 1), 1, 0, (1, 1, 1))
        self.VIP_2 = nn.MaxPool3d((n_segment // 2, 1, 1), 1, 0, (2, 1, 1))
        self.VIP_4 = nn.MaxPool3d((n_segment // 4, 1, 1), 1, 0, (4, 1, 1))
        self.VIP_1_dropout = nn.Dropout(p=dropout_ratio)
        self.VIP_2_dropout = nn.Dropout(p=dropout_ratio)
        self.VIP_4_dropout = nn.Dropout(p=dropout_ratio)
        self.VIP_1_pred = nn.Linear(1 * feature_dim, num_class)
        self.VIP_2_pred = nn.Linear(2 * feature_dim, num_class)
        self.VIP_4_pred = nn.Linear(4 * feature_dim, num_class)
        self.n_segment = n_segment
        self.feature_dim = feature_dim
        for pred in (self.VIP_1_pred, self.VIP_2_pred, self.VIP_4_pred):
            nn.init.normal_(pred.weight.data, 0, 0.001)
            nn.init.constant_(pred.bias.data, 0)

    def forward(self, x):
        """Return fused class scores for each clip.

        Args:
            x: Per-segment features whose first dimension folds into
                ``clips * n_segment`` and whose second dimension is the
                feature size, e.g. ``(B * n_segment, feature_dim)``.

        Returns:
            Tensor of shape ``(clips, num_class)``.
        """
        # (clips, C, n_segment, 1, 1): time becomes the pooled depth axis.
        x = x.reshape(-1, self.n_segment, x.size(1), 1, 1).permute(0, 2, 1, 3, 4)
        # flatten(1) keeps one row per clip; if the channel count does not
        # match feature_dim the linear layer reports it instead of the rows
        # being silently re-batched.
        x1 = self.VIP_1_pred(self.VIP_1_dropout(self.VIP_1(x)).flatten(1))
        x2 = self.VIP_2_pred(self.VIP_2_dropout(self.VIP_2(x)).flatten(1))
        x4 = self.VIP_4_pred(self.VIP_4_dropout(self.VIP_4(x)).flatten(1))
        return 1.0 * x1 + 2.0 * x2 + 4.0 * x4
|