|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """YOLOv3_tiny Network """
- import mindspore as ms
- import mindspore.ops as ops
- import mindspore.nn as nn
- from mindspore import Tensor, context
- from mindspore.ops import operations as P
- from mindspore.ops import functional as F
- from mindspore.ops import composite as C
-
- from mindspore.context import ParallelMode
- from mindspore.parallel._auto_parallel_context import auto_parallel_context
- from mindspore.communication.management import get_group_size
-
- from src.loss import XYLoss, WHLoss, ConfidenceLoss, ClassLoss
- from model_utils.config import config as default_config
-
-
-
class zeropad(nn.Cell):
    """
    Append one row of zeros after the last row and one column of zeros after
    the last column of a 4-D tensor (axes 2 and 3; presumably NCHW).

    The first two axes are left untouched.
    """

    def __init__(self):
        super(zeropad, self).__init__()
        untouched = (0, 0)
        one_after = (0, 1)
        self.pad = ops.Pad(paddings=(untouched, untouched, one_after, one_after))

    def construct(self, x):
        padded = self.pad(x)
        return padded
-
-
class YOLOV3_Tiny_Backbone(nn.Cell):
    """
    YOLOv3-tiny backbone: six conv+BN+LeakyReLU stages separated by max-pools.

    Args:
        config: configuration providing ``in_channels`` and ``out_channels``
            sequences used for the per-stage channel widths.
            Default: ``default_config``.

    Inputs:
        - **x** (Tensor) - input image batch with 3 channels on axis 1.

    Returns:
        Tuple ``(c9, c12)``:
        ``c9`` is the output of ``conv4`` (after four stride-2 max-pools) and
        ``c12`` is the output of the final padded stride-1 max-pool.
    """

    def __init__(self, config=default_config):
        super(YOLOV3_Tiny_Backbone, self).__init__()
        self.config = config
        self.in_channel = self.config.in_channels
        self.out_channel = self.config.out_channels

        def conv_bn_leaky(in_ch, out_ch):
            # Every backbone conv shares these settings: 3x3, stride 1, no bias,
            # BatchNorm (momentum 0.03, eps 1e-3) and LeakyReLU(0.1).
            return nn.Conv2dBnAct(in_ch, out_ch, kernel_size=3, stride=1,
                                  has_bias=False, has_bn=True, momentum=0.03,
                                  alpha=0.1, activation='leakyrelu', eps=0.001)

        self.conv0 = conv_bn_leaky(3, self.in_channel[0])
        self.maxpool0 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv1 = conv_bn_leaky(self.in_channel[0], self.out_channel[0])
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = conv_bn_leaky(self.in_channel[1], self.out_channel[1])
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = conv_bn_leaky(self.in_channel[2], self.out_channel[2])
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = conv_bn_leaky(self.in_channel[3], self.out_channel[3])
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv5 = conv_bn_leaky(self.in_channel[4], self.out_channel[4])
        # Zero-pad bottom/right by one, then a 2x2 stride-1 pool.
        self.pad = zeropad()
        self.maxpool5 = nn.SequentialCell([self.pad,
                                           nn.MaxPool2d(kernel_size=2, stride=1)])

    def construct(self, x):
        feat = self.conv0(x)
        feat = self.maxpool0(feat)
        feat = self.conv1(feat)
        feat = self.maxpool1(feat)
        feat = self.conv2(feat)
        feat = self.maxpool2(feat)
        feat = self.conv3(feat)
        feat = self.maxpool3(feat)
        # Finer-resolution feature kept for the route/skip connection.
        route = self.conv4(feat)
        feat = self.maxpool4(route)
        feat = self.conv5(feat)
        deepest = self.maxpool5(feat)
        return route, deepest
-
class Detection_Block(nn.Cell):
    """
    YOLO detection head for one output scale.

    Decodes the raw convolutional output of one scale into box centres, box
    sizes, objectness and class probabilities, using three anchors.

    Args:
        scale (str): 'm' selects anchors 0-2 of ``config.anchor_scales``,
            'l' selects anchors 3-5.
        config: configuration providing ``anchor_scales`` and ``num_classes``.
            Default: ``default_config``.
        training (bool): selects the output format (see Returns). Default: True.

    Inputs:
        - **x** (Tensor) - raw head output whose shape can be reshaped to
          (batch, 3, 5 + num_classes, grid_h, grid_w).
        - **input_shape** (Tensor) - float32 tensor that the decoded box sizes
          are divided by (the visible caller passes the image (h, w)).

    Returns:
        If ``training`` is True, the tuple ``(grid, prediction, box_xy, box_wh)``
        where ``prediction`` is the raw output transposed to
        (batch, grid_h, grid_w, 3, 5 + num_classes). Otherwise a single tensor
        concatenating ``(box_xy, box_wh, box_confidence, box_probs)`` along the
        last axis.

    Raises:
        KeyError: if ``scale`` is neither 'm' nor 'l'.

    Examples:
        >>> head = Detection_Block('l', config=config)
    """

    def __init__(self, scale, config=default_config, training=True):
        super(Detection_Block, self).__init__()

        if scale == 'm':
            idx = (0, 1, 2)
        elif scale == 'l':
            idx = (3, 4, 5)
        else:
            raise KeyError("Invalid scale value for DetectionBlock")
        self.config = config
        self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
        self.num_anchors_per_scale = 3
        # x, y, w, h + objectness + class scores
        self.num_attrib = 4 + 1 + self.config.num_classes
        self.lambda_coord = 1

        self.sigmoid = nn.Sigmoid()
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.cast = P.Cast()
        self.exp = P.Exp()
        self.tile = P.Tile()
        self.concat = P.Concat(axis=-1)
        self.training = training

    def construct(self, x, input_shape):
        x_shape = P.Shape()(x)
        num_batch = x_shape[0]
        grid_size = x_shape[2:4]  # (grid_h, grid_w)

        # [batch, 3 * num_attrib, grid_h, grid_w] -> [batch, grid_h, grid_w, 3, num_attrib]
        prediction = self.reshape(x, (num_batch,
                                      self.num_anchors_per_scale,
                                      self.num_attrib,
                                      grid_size[0],
                                      grid_size[1]))
        prediction = self.transpose(prediction, (0, 3, 4, 1, 2))

        # Per-cell offsets: grid_x counts along the width axis, grid_y along the
        # height axis; both are tiled to [1, grid_h, grid_w, 1, 1].
        range_x = range(grid_size[1])
        range_y = range(grid_size[0])
        grid_x = self.cast(F.tuple_to_array(range_x), ms.float32)
        grid_y = self.cast(F.tuple_to_array(range_y), ms.float32)
        grid_x = self.tile(self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1))
        grid_y = self.tile(self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1))
        # [1, grid_h, grid_w, 1, 2], last axis ordered (x, y)
        grid = self.concat((grid_x, grid_y))

        box_xy = prediction[:, :, :, :, :2]
        box_wh = prediction[:, :, :, :, 2:4]
        box_confidence = prediction[:, :, :, :, 4:5]
        box_probs = prediction[:, :, :, :, 5:]

        # Normalise centres by (grid_w, grid_h) to match the (x, y) order of grid.
        grid_wh = self.cast(F.tuple_to_array((grid_size[1], grid_size[0])), ms.float32)
        box_xy = (self.sigmoid(box_xy) + grid) / grid_wh
        # box_wh is ordered (w, h).
        # NOTE(review): input_shape comes from F.shape(x)[2:4], i.e. (h, w), so the
        # division below pairs w with h; this is only exact for square inputs.
        box_wh = self.exp(box_wh) * self.anchors / input_shape
        box_confidence = self.sigmoid(box_confidence)
        box_probs = self.sigmoid(box_probs)

        if self.training:
            return grid, prediction, box_xy, box_wh
        return self.concat((box_xy, box_wh, box_confidence, box_probs))
-
-
class YOLOV3_Tiny(nn.Cell):
    """
    YOLOv3-tiny network: a backbone followed by two detection heads.

    The deepest backbone feature (c12) feeds the 'l' head. A reduced copy of
    it is upsampled to the spatial size of the finer backbone feature (c9),
    concatenated with it, and feeds the 'm' head.

    Args:
        config: configuration passed to the backbone and detection heads;
            ``num_classes`` sets the head output width. Default: ``default_config``.
        backbone: backbone class, instantiated as ``backbone(config)``; its
            ``construct`` must return ``(c9, c12)``.
            Default: ``YOLOV3_Tiny_Backbone``.
        detectionBlock: head class, instantiated as
            ``detectionBlock(scale, config, training)``. Default: ``Detection_Block``.
        training (bool): forwarded to both detection heads. Default: True.

    Returns:
        Tuple ``(output_l, output_m)`` of the two detection heads' outputs.
    """

    def __init__(self, config=default_config,
                 backbone=YOLOV3_Tiny_Backbone,
                 detectionBlock=Detection_Block,
                 training=True):
        super(YOLOV3_Tiny, self).__init__()
        self.config = config
        self.backbone = backbone(self.config)

        self.detection_l = detectionBlock('l', self.config, training)
        self.detection_m = detectionBlock('m', self.config, training)

        # The hard-coded widths imply c12 has 512 channels and c9 has 256
        # (conv5 input 384 = 128 from conv4 + 256 from c9).
        # NOTE(review): these must agree with config.out_channels.
        self.conv1 = nn.Conv2dBnAct(512, 1024, kernel_size=3, stride=1, has_bias=False, has_bn=True,
                                    momentum=0.03, alpha=0.1, activation='leakyrelu', eps=0.001)
        self.conv2 = nn.Conv2dBnAct(1024, 256, kernel_size=1, stride=1, has_bias=False, has_bn=True,
                                    momentum=0.03, alpha=0.1, activation='leakyrelu', eps=0.001)
        self.conv3 = nn.Conv2dBnAct(256, 512, kernel_size=3, stride=1, has_bias=False, has_bn=True,
                                    momentum=0.03, alpha=0.1, activation='leakyrelu', eps=0.001)
        self.conv2d1 = nn.Conv2d(512, ((self.config.num_classes + 5) * 3),
                                 kernel_size=1, stride=1, has_bias=True)

        self.conv4 = nn.Conv2dBnAct(256, 128, kernel_size=1, stride=1, has_bias=False, has_bn=True,
                                    momentum=0.03, alpha=0.1, activation='leakyrelu', eps=0.001)
        self.conv5 = nn.Conv2dBnAct(384, 256, kernel_size=3, stride=1, has_bias=False, has_bn=True,
                                    momentum=0.03, alpha=0.1, activation='leakyrelu', eps=0.001)
        self.conv2d2 = nn.Conv2d(256, ((self.config.num_classes + 5) * 3),
                                 kernel_size=1, stride=1, has_bias=True)

        self.concat = P.Concat(axis=1)

        self.tenser_to_array = P.TupleToArray()

    def construct(self, x):
        # x is (batch_size, 3, h, w); the heads divide decoded box sizes by (h, w).
        input_shape = F.shape(x)[2:4]
        input_shape = F.cast(self.tenser_to_array(input_shape), ms.float32)

        c9, c12 = self.backbone(x)
        c13 = self.conv1(c12)
        c14 = self.conv2(c13)
        c15 = self.conv3(c14)

        c16 = self.conv4(c14)
        # Nearest-neighbour upsample straight to c9's spatial size so the
        # channel concat below always lines up, independent of the backbone's
        # total stride (previously hard-coded as h // 16, w // 16).
        c17 = P.ResizeNearestNeighbor(P.Shape()(c9)[2:4])(c16)
        c18 = self.concat((c17, c9))
        c19 = self.conv5(c18)

        c19 = self.conv2d2(c19)
        c15 = self.conv2d1(c15)

        output_m = self.detection_m(c19, input_shape)
        output_l = self.detection_l(c15, input_shape)

        return output_l, output_m
-
-
class Iou(nn.Cell):
    """
    Broadcast IoU between predicted boxes and ground-truth boxes.

    Both inputs hold (x_center, y_center, w, h) on their last axis and are
    sliced as 6-D tensors. The expected broadcast layout is:
        box1 (pred): [batch, gx, gy, anchors, 1, 4]
        box2 (gt):   [batch, 1, 1, 1, maxbox, 4]
    The result drops the coordinate axis: [batch, gx, gy, anchors, maxbox].
    """

    def __init__(self):
        super(Iou, self).__init__()
        self.min = P.Minimum()
        self.max = P.Maximum()
        self.squeeze = P.Squeeze(-1)

    def construct(self, box1, box2):
        two = F.scalar_to_array(2.0)

        # Centre/size -> top-left and bottom-right corners.
        pred_xy = box1[:, :, :, :, :, :2]
        pred_wh = box1[:, :, :, :, :, 2:4]
        pred_half = pred_wh / two
        pred_top_left = pred_xy - pred_half
        pred_bottom_right = pred_xy + pred_half

        gt_xy = box2[:, :, :, :, :, :2]
        gt_wh = box2[:, :, :, :, :, 2:4]
        gt_half = gt_wh / two
        gt_top_left = gt_xy - gt_half
        gt_bottom_right = gt_xy + gt_half

        overlap_tl = self.max(pred_top_left, gt_top_left)
        overlap_br = self.min(pred_bottom_right, gt_bottom_right)
        # Clamp at zero so disjoint boxes have no overlap.
        overlap_wh = self.max(overlap_br - overlap_tl, F.scalar_to_array(0.0))

        # Squeeze single-column slices instead of integer indexing (cheaper slice).
        overlap_area = self.squeeze(overlap_wh[:, :, :, :, :, 0:1]) * \
            self.squeeze(overlap_wh[:, :, :, :, :, 1:2])
        pred_area = self.squeeze(pred_wh[:, :, :, :, :, 0:1]) * self.squeeze(pred_wh[:, :, :, :, :, 1:2])
        gt_area = self.squeeze(gt_wh[:, :, :, :, :, 0:1]) * self.squeeze(gt_wh[:, :, :, :, :, 1:2])

        union_area = pred_area + gt_area - overlap_area
        return overlap_area / union_area
-
-
class YOLO_Loss_Block(nn.Cell):
    """
    YOLO loss for a single detection scale.

    Combines xy, wh, objectness and class losses (provided by ``src.loss``)
    and divides the total by the batch size.

    Args:
        scale (str): 'm' uses anchors 0-2 of ``config.anchor_scales``,
            'l' uses anchors 3-5.
        config: configuration providing ``anchor_scales`` and
            ``ignore_threshold``. Default: ``default_config``.

    Raises:
        KeyError: if ``scale`` is neither 'm' nor 'l'.
    """

    def __init__(self, scale, config=default_config):
        super(YOLO_Loss_Block, self).__init__()
        if scale == 'l':
            idx = (3, 4, 5)
        elif scale == 'm':
            idx = (0, 1, 2)
        else:
            raise KeyError("Invalid scale value for DetectionBlock")
        self.config = config
        self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
        self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32)
        self.concat = P.Concat(axis=-1)
        self.iou = Iou()
        self.reduce_max = P.ReduceMax(keep_dims=False)
        self.xy_loss = XYLoss()
        self.wh_loss = WHLoss()
        self.confidenceLoss = ConfidenceLoss()
        self.classLoss = ClassLoss()

    def construct(self, grid, prediction, pred_xy, pred_wh, y_true, gt_box, input_shape):
        # prediction: raw head output, last axis = (tx, ty, tw, th, conf, classes...)
        # pred_xy:    (sigmoid(xy) + grid) / grid_size
        # pred_wh:    exp(wh) * anchors / input_shape
        # y_true:     normalized targets, same last-axis layout as prediction
        # gt_box:     [batch, maxboxes, 4] normalized boxes
        pred_shape = P.Shape()(prediction)
        obj_mask = y_true[:, :, :, :, 4:5]
        target_cls = y_true[:, :, :, :, 5:]

        # (grid_h, grid_w) reversed to (grid_w, grid_h) to match the x, y order.
        grid_wh = P.Cast()(F.tuple_to_array(pred_shape[1:3][::-1]), ms.float32)

        # Bring targets into the raw (pre-activation) space of the head.
        target_xy = y_true[:, :, :, :, :2] * grid_wh - grid
        target_wh = y_true[:, :, :, :, 2:4]
        # Cells without an object have wh == 0; use 1 there so the log stays finite.
        is_zero = P.Equal()(target_wh, 0.0)
        ones = P.Fill()(P.DType()(target_wh), P.Shape()(target_wh), 1.0)
        target_wh = P.Select()(is_zero, ones, target_wh)
        target_wh = P.Log()(target_wh / self.anchors * input_shape)
        # 2 - w*h: smaller boxes get a larger weight, since they need more precision.
        box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]

        # Reshape gt boxes for broadcasting against every cell and anchor.
        b, n_gt, n_coord = P.Shape()(gt_box)
        gt_box = P.Reshape()(gt_box, (b, 1, 1, 1, n_gt, n_coord))

        pred_boxes = self.concat((pred_xy, pred_wh))
        # [batch, grid_h, grid_w, num_anchor, num_gt]
        iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box)
        # [batch, grid_h, grid_w, num_anchor]
        best_iou = self.reduce_max(iou, -1)

        # 1 where the best IoU is below the threshold. No gradient is needed, and
        # back-propagating through it costs a lot of Maximum/Minimum grad time.
        ignore_mask = P.ExpandDims()(P.Cast()(best_iou < self.ignore_threshold, ms.float32), -1)
        ignore_mask = F.stop_gradient(ignore_mask)

        loss = self.xy_loss(obj_mask, box_loss_scale, prediction[:, :, :, :, :2], target_xy)
        loss = loss + self.wh_loss(obj_mask, box_loss_scale, prediction[:, :, :, :, 2:4], target_wh)
        loss = loss + self.confidenceLoss(obj_mask, prediction[:, :, :, :, 4:5], ignore_mask)
        loss = loss + self.classLoss(obj_mask, prediction[:, :, :, :, 5:], target_cls)
        return loss / pred_shape[0]
-
-
class YOLOWithLossCell(nn.Cell):
    """
    Wrap a YOLOv3-tiny network with its two per-scale losses.

    Args:
        network (Cell): network returning ``(output_l, output_m)``, each a
            tuple accepted as the leading arguments of ``YOLO_Loss_Block``.
        config: configuration forwarded to both loss blocks.
            Default: ``default_config``.

    Inputs:
        - **x** (Tensor) - image batch; axes 2 and 3 give (h, w).
        - **y_true_0**, **gt_0** - targets for the 'l' scale.
        - **y_true_1**, **gt_1** - targets for the 'm' scale.

    Returns:
        Tensor, the sum of the 'l' and 'm' losses.
    """

    def __init__(self, network, config=default_config):
        super(YOLOWithLossCell, self).__init__()
        self.config = config
        self.yolov3_tiny = network
        self.tenser_to_array = P.TupleToArray()
        self.loss_l = YOLO_Loss_Block('l', self.config)
        self.loss_m = YOLO_Loss_Block('m', self.config)

    def construct(self, x, y_true_0, y_true_1, gt_0, gt_1):
        image_hw = F.cast(self.tenser_to_array(F.shape(x)[2:4]), ms.float32)
        out_l, out_m = self.yolov3_tiny(x)
        total = self.loss_l(*out_l, y_true_0, gt_0, image_hw)
        return total + self.loss_m(*out_m, y_true_1, gt_1, image_hw)
-
-
class TrainingWrapper(nn.Cell):
    """
    One training step: forward, backward, optional gradient reduction, optimizer update.

    Args:
        network (Cell): loss network; called with the wrapper's inputs and
            expected to return the loss tensor.
        optimizer (Cell): optimizer whose ``parameters`` are differentiated
            and updated.
        sens (float): value used to fill the loss sensitivity (the initial
            gradient). Default: 1.0.

    Returns:
        Tensor, the loss produced by ``network`` for this step.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            # Prefer an explicitly configured device count; otherwise use the
            # size of the communication group.
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *args):
        weights = self.weights
        loss = self.network(*args)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        # Tie the returned loss to the optimizer update so the update is always
        # executed and ordered in graph mode (same idiom as nn.TrainOneStepCell);
        # previously the optimizer's result was discarded.
        return F.depend(loss, self.optimizer(grads))
|