|
- # encoding: utf-8
- '''
- 跨模态行人重识别算法封装示例
- '''
- import torch
- from torch import nn
- import torch.nn.functional as F
- from backbones.resnet import ResNet, Bottleneck
-
-
class XIVReID(nn.Module):
    '''
    Cross-modality person re-identification model, wrapped for inference.

    A ResNet backbone (layers [3, 4, 6, 3], last stride 1) followed by global
    average pooling produces one feature vector per image; a bias-free linear
    layer maps that vector to per-class scores.
    '''

    def __init__(self, classes=395, weights_path='/XIVReID/XIVReID/model.pth'):
        '''
        Build the network, load its weights and switch it to eval mode.

        Arg
        ---
        classes: number of classes in the dataset
            type: int
        weights_path: path of the state-dict checkpoint to load
            type: str
        '''
        super(XIVReID, self).__init__()

        self.base = ResNet(last_stride=1, block=Bottleneck, layers=[3, 4, 6, 3])
        self.gap = nn.AdaptiveAvgPool2d(1)

        # NOTE(review): `bottleneck` is created (and loaded from the
        # checkpoint) but neither inference method applies it; the classifier
        # is fed the raw pooled feature. Confirm against the training code
        # that this is intended.
        self.bottleneck = nn.BatchNorm1d(2048)
        self.bottleneck.bias.requires_grad_(False)
        self.classifier = nn.Linear(2048, classes, bias=False)

        # Not used by any method of this class. Presumably kept so that the
        # module's keys match the checkpoint (load_state_dict is strict by
        # default) -- verify before removing.
        self.generate1 = nn.Linear(3, 1)
        self.generate2 = nn.Linear(1, 3)

        # map_location='cpu' lets a checkpoint that was saved from a GPU load
        # on a CPU-only host. SECURITY: torch.load unpickles the file, so only
        # load checkpoints from a trusted source.
        self.load_state_dict(torch.load(weights_path, map_location='cpu'))

        # Every forward path below is inference-only (no_grad). Without eval()
        # the BatchNorm layers would normalise with per-batch statistics and
        # keep updating their running statistics on every call.
        self.eval()

    def _global_feat(self, x):
        '''Run the backbone and global pooling on x; return the result flattened to (N, C).'''
        pooled = self.gap(self.base(x))
        return pooled.view(pooled.shape[0], -1)

    def __call__(self, x):
        '''
        Extract the global feature of a batch of images (NumPy in, NumPy out).

        Note: this overrides nn.Module.__call__, so hooks registered on this
        top-level module are not run.

        Arg
        ---
        input: color images in RGB format
            shape: (N, 3, H, W), N is batch size
            type: numpy.ndarray (converted to float32 internally)
        Return
        ---
        feature of the input pictures
            shape: (N, 2048)
            type: numpy.ndarray (float32)
        '''
        x = torch.from_numpy(x).float()
        # Follow the model if it has been moved off the CPU.
        x = x.to(next(self.parameters()).device)
        with torch.no_grad():
            global_feat = self._global_feat(x)
        return global_feat.cpu().numpy()

    def get_feat_and_cat(self, x):
        '''
        Extract the feature of each input image and classify it.

        Arg
        ---
        input: color images in RGB format
            shape: (N, 3, H, W), N is batch size
            type: tensor.FloatTensor
        Return
        ---
        feature of the input pictures
            shape: (N, 2048)
            type: tensor.FloatTensor
        probability of each class (softmax over the classifier output)
            shape: (N, classes)
            type: tensor.FloatTensor
        '''
        with torch.no_grad():
            global_feat = self._global_feat(x)
            logits = self.classifier(global_feat)
        return global_feat, F.softmax(logits, dim=1)
|