|
- """ Contains functions used to sanitize and prepare the output of Yolact. """
-
-
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
- import numpy as np
- import cv2
-
- from data import cfg, mask_type, MEANS, STD, activation_func
- from utils.augmentations import Resize
- from utils import timer
- from .box_utils import crop, sanitize_coordinates
-
def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
                visualize_lincomb=False, crop_masks=True, score_threshold=0):
    """
    Postprocesses the output of Yolact on testing mode into a format that makes sense,
    accounting for all the possible configuration settings.

    Note: the entries of det_output[batch_idx]['detection'] are filtered in place when
    score_threshold > 0, and the box tensor is clamped in place.

    Args:
        - det_output: The list of dicts that Detect outputs. Each dict is read through
                      its 'net' and 'detection' keys.
        - w: The real width of the image.
        - h: The real height of the image.
        - batch_idx: If you have multiple images for this batch, the image's index in the batch.
        - interpolation_mode: Forwarded as `mode` to paddle.nn.functional.interpolate
                              (e.g. 'nearest' | 'area' | 'bilinear').
        - visualize_lincomb: If True (lincomb masks only), call display_lincomb on the
                             prototypes and coefficients before assembling the masks.
        - crop_masks: If True (lincomb masks only), crop each mask with its box before upsampling.
        - score_threshold: Detections whose score is not strictly greater than this value
                           are dropped. Values <= 0 disable the filtering.

    Returns 4 paddle Tensors (in the following order):
        - classes [num_det]: The class idx for each detection.
        - scores  [num_det]: The confidence score for each detection. When cfg.use_maskiou and
                             cfg.rescore_mask are set but cfg.rescore_bbox is not, this is
                             instead a list [scores, scores * maskiou].
        - boxes   [num_det, 4]: The bounding box for each detection in absolute point form.
        - masks   [num_det, h, w]: Full image masks for each detection.
    If there are no detections, a list of 4 references to one empty paddle.Tensor is returned.
    """

    dets = det_output[batch_idx]
    net = dets['net']
    dets = dets['detection']

    if dets is None:
        return [paddle.Tensor()] * 4  # Warning, this is 4 copies of the same thing

    if score_threshold > 0:
        # Compute the surviving indices once; they are the same for every entry.
        keep_idx = paddle.nonzero(dets['score'] > score_threshold).flatten()

        for k in dets:
            # 'proto' is shared by all detections, so it is not indexed per detection.
            if k != 'proto':
                dets[k] = paddle.gather(dets[k], keep_idx, axis=0)

        if dets['score'].shape[0] == 0:
            return [paddle.Tensor()] * 4

    # Actually extract everything from dets now
    classes = dets['class']
    boxes = dets['box']
    scores = dets['score']
    masks = dets['mask']

    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        # At this point masks is only the coefficients
        proto_data = dets['proto']

        # Debug flag: dump the raw prototypes to disk
        if cfg.mask_proto_debug:
            np.save('scripts/proto.npy', proto_data.cpu().numpy())

        if visualize_lincomb:
            display_lincomb(proto_data, masks)

        # Linear combination of the prototypes with the per-detection coefficients
        masks = paddle.mm(proto_data, masks.t())
        masks = cfg.mask_proto_mask_activation(masks)

        # Crop masks before upsampling
        if crop_masks:
            masks = crop(masks, boxes)

        # Permute into the correct output shape [num_dets, proto_h, proto_w]
        masks = masks.transpose((2, 0, 1))

        if cfg.use_maskiou:
            with timer.env('maskiou_net'):
                with paddle.no_grad():
                    maskiou_p = net.maskiou_net(masks.unsqueeze(1))
                    # Per-row lookup of each detection's own class column.
                    # paddle.gather(axis=1) would select whole columns instead, which is
                    # not what torch.gather did in the code this was ported from.
                    maskiou_p = paddle.take_along_axis(
                        maskiou_p, classes.unsqueeze(1), axis=1)[:, 0]
                    if cfg.rescore_mask:
                        if cfg.rescore_bbox:
                            scores = scores * maskiou_p
                        else:
                            scores = [scores, scores * maskiou_p]

        # Scale masks up to the full image
        masks = F.interpolate(masks.unsqueeze(0), (h, w), mode=interpolation_mode,
                              align_corners=False).squeeze(0)

        # Binarize the masks
        masks = masks > 0.5

    # Clamp the boxes to the image (written back into the box tensor in place)
    boxes[:, 0], boxes[:, 2] = sanitize_coordinates(boxes[:, 0], boxes[:, 2], w, cast=False)
    boxes[:, 1], boxes[:, 3] = sanitize_coordinates(boxes[:, 1], boxes[:, 3], h, cast=False)
    boxes = boxes.astype('int64')

    if cfg.mask_type == mask_type.direct and cfg.eval_mask_branch:
        # Upscale masks: paste each fixed-size mask into its box on an empty canvas.
        # paddle.zeros takes the shape as a single sequence.
        full_masks = paddle.zeros([masks.shape[0], h, w])

        for jdx in range(masks.shape[0]):
            # Plain Python ints so they can be used as sizes and slice bounds.
            x1, y1, x2, y2 = (int(v) for v in boxes[jdx].numpy().tolist())

            mask_w = x2 - x1
            mask_h = y2 - y1

            # Just in case: skip degenerate boxes
            if mask_w * mask_h <= 0 or mask_w < 0:
                continue

            mask = masks[jdx, :].reshape([1, 1, cfg.mask_size, cfg.mask_size])
            mask = F.interpolate(mask, (mask_h, mask_w), mode=interpolation_mode,
                                 align_corners=False)
            mask = (mask > 0.5).astype('float32')
            # Drop the two leading singleton dims so the shapes match the target slice.
            full_masks[jdx, y1:y2, x1:x2] = mask[0, 0]

        masks = full_masks

    return classes, scores, boxes, masks
-
-
-
-
-
def undo_image_transformation(img, w, h):
    """
    Takes a transformed image tensor and returns a numpy ndarray that is untransformed.

    Args:
        - img: The transformed image tensor. Its first axis is moved to the last position
               before the channel reordering, i.e. it is treated as channel-first.
        - w: The original width of the image.
        - h: The original height of the image.

    Returns:
        A numpy ndarray with values clipped to [0, 1], resized with cv2.resize to (w, h).
    """
    # Move the channel axis last. Paddle's transpose takes the permutation as a
    # sequence, matching how postprocess() permutes its masks.
    img_numpy = img.transpose((1, 2, 0)).cpu().numpy()
    img_numpy = img_numpy[:, :, (2, 1, 0)]  # To BGR

    # Invert whichever normalization the backbone transform applied
    if cfg.backbone.transform.normalize:
        img_numpy = (img_numpy * np.array(STD) + np.array(MEANS)) / 255.0
    elif cfg.backbone.transform.subtract_means:
        img_numpy = (img_numpy / 255.0 + np.array(MEANS) / 255.0).astype(np.float32)

    img_numpy = img_numpy[:, :, (2, 1, 0)]  # To RGB
    img_numpy = np.clip(img_numpy, 0, 1)

    return cv2.resize(img_numpy, (w, h))
-
-
def display_lincomb(proto_data, masks):
    """
    Debug visualization of how the prototypes combine into the mask of the first detection.

    Shows two matplotlib figures (each call to plt.show() blocks until closed when an
    interactive backend is used):
        1. A 4x8 grid of prototypes, ordered by decreasing absolute coefficient, each
           divided by its own maximum and scaled by its coefficient.
        2. The raw (pre-activation) linear combination for detection 0.

    Args:
        - proto_data: Prototype tensor, unpacked as [proto_h, proto_w, num_protos].
        - masks: Coefficient tensor; row 0 is used as the coefficients of detection 0.
    """
    import matplotlib.pyplot as plt

    out_masks = paddle.matmul(proto_data, masks.t())

    jdx = 0  # Only the first detection is visualized
    coeffs = masks[jdx, :].cpu().numpy()
    idx = np.argsort(-np.abs(coeffs))
    coeffs_sort = coeffs[idx]

    arr_h, arr_w = (4, 8)
    # `.shape` (not torch's `.size()`) gives the dimensions of a paddle Tensor.
    proto_h, proto_w, _ = proto_data.shape
    arr_img = np.zeros([proto_h * arr_h, proto_w * arr_w])
    # Thresholded running sum per tile; filled in for debugging but not shown by default.
    arr_run = np.zeros([proto_h * arr_h, proto_w * arr_w])

    use_sigmoid = cfg.mask_proto_mask_activation == activation_func.sigmoid

    # Never index past the available prototypes if there are fewer than 32 of them.
    num_tiles = min(arr_h * arr_w, idx.shape[0])
    running_total = None

    for i in range(num_tiles):
        y, x = divmod(i, arr_w)

        # Index with a plain int rather than a numpy integer scalar.
        proto_np = proto_data[:, :, int(idx[i])].cpu().numpy()
        weighted = proto_np * coeffs_sort[i]
        running_total = weighted if running_total is None else running_total + weighted

        running_total_nonlin = running_total
        if use_sigmoid:
            running_total_nonlin = 1 / (1 + np.exp(-running_total_nonlin))

        rows = slice(y * proto_h, (y + 1) * proto_h)
        cols = slice(x * proto_w, (x + 1) * proto_w)
        arr_img[rows, cols] = proto_np / proto_np.max() * coeffs_sort[i]
        # np.float was removed in NumPy 1.24; it was an alias of the builtin float (float64).
        arr_run[rows, cols] = (running_total_nonlin > 0.5).astype(np.float64)

    plt.imshow(arr_img)
    plt.show()
    plt.imshow(out_masks[:, :, jdx].cpu().numpy())
    plt.show()
|