|
- import math
- from urllib.request import urlretrieve
- import torch
- from PIL import Image
- from tqdm import tqdm
- import numpy as np
- import random
- import torch.nn.functional as F
-
class Warp(object):
    """Resize an image to a fixed ``size x size`` square, ignoring its aspect ratio."""

    def __init__(self, size, interpolation=Image.BILINEAR):
        """
        Args:
            size: side length of the output square (coerced with ``int``).
            interpolation: resampling filter forwarded to ``img.resize``.
        """
        self.size = int(size)
        self.interpolation = interpolation

    def __call__(self, img):
        """Return ``img`` resized to ``(self.size, self.size)``."""
        side = self.size
        return img.resize((side, side), self.interpolation)

    def __str__(self):
        fields = ' (size={size}, interpolation={interpolation})'.format(
            interpolation=self.interpolation, size=self.size)
        return self.__class__.__name__ + fields
class MultiScaleCrop(object):
    """Random multi-scale crop followed by a resize to ``input_size``.

    Candidate crop side lengths are ``scales`` fractions of the image's shorter
    side; a candidate within 2 px of the target size snaps to it. A
    (width, height) pair is drawn from those candidates such that the two scale
    indices differ by at most ``max_distort``. The crop position is either
    uniformly random (``fix_crop=False``) or one of a fixed set of anchor
    points (see ``fill_fix_offset``).
    """

    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True,
                 interpolation=None):
        """
        Args:
            input_size (int or sequence): output size; an int means a square
                output, otherwise ``(width, height)``.
            scales (list of float, optional): crop sizes as fractions of the
                shorter image side. Defaults to ``[1, .875, .75, .66]``.
            max_distort (int): maximum index distance between the width scale
                and the height scale.
            fix_crop (bool): pick offsets from fixed anchors instead of uniformly.
            more_fix_crop (bool): with ``fix_crop``, use 13 anchors instead of 5.
            interpolation (optional): resampling filter for the final resize;
                ``None`` keeps the previous fixed value ``Image.BILINEAR``.
        """
        # The 7/8 scale must be .875; the former literal 875 produced crops far
        # larger than the image (negative offsets / randint ValueError).
        self.scales = scales if scales is not None else [1, .875, .75, .66]
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
        self.interpolation = interpolation if interpolation is not None else Image.BILINEAR

    def __call__(self, img):
        """Crop ``img`` at a sampled scale/position and resize it to ``input_size``.

        ``img`` must provide ``size``, ``crop`` and ``resize`` (PIL-style image).
        """
        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(img.size)
        cropped = img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h))
        return cropped.resize((self.input_size[0], self.input_size[1]), self.interpolation)

    def _sample_crop_size(self, im_size):
        """Return ``(crop_w, crop_h, offset_w, offset_h)`` for an image of size ``(w, h)``."""
        image_w, image_h = im_size[0], im_size[1]

        # Candidate side lengths relative to the shorter side; snap near-matches
        # (difference < 3 px) to the exact target size.
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]

        # Limit aspect distortion: width/height scale indices may differ by at most max_distort.
        pairs = []
        for i, h in enumerate(crop_h):
            for j, w in enumerate(crop_w):
                if abs(i - j) <= self.max_distort:
                    pairs.append((w, h))

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        """Pick one of the fixed anchor offsets uniformly at random."""
        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        """List the fixed ``(offset_w, offset_h)`` anchors for a crop of ``crop_w x crop_h``.

        Offsets lie on a 5x5 grid whose step is a quarter of the free space in
        each direction: 5 anchors (corners + center), or 13 when ``more_fix_crop``.
        """
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = list()
        ret.append((0, 0))  # upper left
        ret.append((4 * w_step, 0))  # upper right
        ret.append((0, 4 * h_step))  # lower left
        ret.append((4 * w_step, 4 * h_step))  # lower right
        ret.append((2 * w_step, 2 * h_step))  # center

        if more_fix_crop:
            ret.append((0, 2 * h_step))  # center left
            ret.append((4 * w_step, 2 * h_step))  # center right
            ret.append((2 * w_step, 4 * h_step))  # lower center
            ret.append((2 * w_step, 0 * h_step))  # upper center

            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
            ret.append((3 * w_step, 3 * h_step))  # lower right quarter

        return ret

    def __str__(self):
        return self.__class__.__name__
-
-
def download_url(url, destination=None, progress_bar=True):
    """Download a URL to a local file.

    Parameters
    ----------
    url : str
        The URL to download.
    destination : str, None
        The destination of the file. If None is given, ``urlretrieve`` chooses
        a temporary file.
    progress_bar : bool
        Whether to show a command-line progress bar while downloading.

    Returns
    -------
    filename : str
        The location of the downloaded file.

    Notes
    -----
    Progress bar use/example adapted from tqdm documentation: https://github.com/tqdm/tqdm
    """

    def my_hook(t):
        """Adapt the tqdm bar ``t`` to urlretrieve's ``reporthook(b, bsize, tsize)`` protocol."""
        last_b = [0]  # blocks reported so far; a one-element list so inner() can update it

        def inner(b=1, bsize=1, tsize=None):
            # urlretrieve passes tsize == -1 when the total size is unknown;
            # don't hand that to tqdm as a total.
            if tsize not in (None, -1):
                t.total = tsize
            if b > 0:
                t.update((b - last_b[0]) * bsize)
            last_b[0] = b

        return inner

    if progress_bar:
        with tqdm(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t:
            filename, _ = urlretrieve(url, filename=destination, reporthook=my_hook(t))
    else:
        filename, _ = urlretrieve(url, filename=destination)
    # Previously the filename was computed but never returned.
    return filename
-
-
class AveragePrecisionMeter(object):
    """
    The APMeter measures the average precision per class.
    The APMeter is designed to operate on `NxK` Tensors `output` and
    `target`, where (1) the `output` contains model output scores for `N`
    examples and `K` classes that ought to be higher when the model is more
    convinced that the example should be positively labeled, and smaller when
    the model believes the example should be negatively labeled; and (2) the
    `target` marks positives with 1 and negatives with -1 (or 0). When
    `difficult_examples` is True, a target of 0 marks a "difficult" example
    that `value()` ignores. Per-sample weights are not supported.
    """

    def __init__(self, difficult_examples=False):
        super(AveragePrecisionMeter, self).__init__()
        self.reset()
        # When True, targets equal to 0 are dropped from the AP ranking in value().
        self.difficult_examples = difficult_examples

    def reset(self):
        """Resets the meter with empty member variables"""
        # Empty tensors whose storages add() grows in place.
        # NOTE(review): Float/LongStorage and Tensor.storage() are deprecated
        # (TypedStorage) in torch 2.x; they still work but emit warnings.
        self.scores = torch.FloatTensor(torch.FloatStorage())
        self.targets = torch.LongTensor(torch.LongStorage())

    def add(self, output, target):
        """Append a batch of scores and labels.

        Args:
            output (Tensor or ndarray): NxK scores (or length-N for K=1); higher
                means the model is more confident the class is present.
            target (Tensor or ndarray): labels with the same shape as
                ``output`` (eg: a row [1, -1, -1, 1] indicates that the example
                is associated with classes 1 and 4)
        """
        if not torch.is_tensor(output):
            output = torch.from_numpy(output)
        if not torch.is_tensor(target):
            target = torch.from_numpy(target)

        if output.dim() == 1:
            output = output.view(-1, 1)
        else:
            assert output.dim() == 2, \
                'wrong output size (should be 1D or 2D with one column per class)'
        if target.dim() == 1:
            target = target.view(-1, 1)
        else:
            assert target.dim() == 2, \
                'wrong target size (should be 1D or 2D with one column per class)'
        # Rows/columns of scores and targets must stay aligned for value()/overall().
        assert output.size() == target.size(), \
            'output and target should have the same shape.'
        if self.scores.numel() > 0:
            assert output.size(1) == self.scores.size(1), \
                'dimensions for output should match previously added examples.'

        # Grow the backing storage geometrically (x1.5 plus this batch) so that
        # repeated add() calls do not reallocate every time.
        if self.scores.storage().size() < self.scores.numel() + output.numel():
            new_size = math.ceil(self.scores.storage().size() * 1.5)
            self.scores.storage().resize_(int(new_size + output.numel()))
            self.targets.storage().resize_(int(new_size + output.numel()))

        # store scores and targets after the rows already present
        offset = self.scores.size(0) if self.scores.dim() > 0 else 0
        self.scores.resize_(offset + output.size(0), output.size(1))
        self.targets.resize_(offset + target.size(0), target.size(1))
        self.scores.narrow(0, offset, output.size(0)).copy_(output)
        self.targets.narrow(0, offset, target.size(0)).copy_(target)

    def value(self):
        """Returns the model's average precision for each class.

        Return:
            ap (FloatTensor): 1-D tensor of length K with the average precision
            of each class k, or the int 0 if nothing has been added yet.
        """
        if self.scores.numel() == 0:
            return 0
        ap = torch.zeros(self.scores.size(1))
        for k in range(self.scores.size(1)):
            ap[k] = AveragePrecisionMeter.average_precision(
                self.scores[:, k], self.targets[:, k], self.difficult_examples)
        return ap

    @staticmethod
    def average_precision(output, target, difficult_examples=True):
        """Non-interpolated average precision of a single class.

        Args:
            output (Tensor): 1-D scores, one per example (a column, as passed by value()).
            target (Tensor): labels aligned with ``output``; 1 is positive.
            difficult_examples (bool): if True, examples labelled 0 are removed
                from the ranking entirely.

        Returns:
            float: mean of precision@i over the ranks i of the positive
            examples, or 0.0 when there are no positives (this previously
            raised ZeroDivisionError).
        """
        # Rank examples by descending score and read their labels in that order.
        _, indices = torch.sort(output, dim=0, descending=True)
        ranked = target[indices]
        if difficult_examples:
            ranked = ranked[ranked != 0]
        is_pos = (ranked == 1).double()
        num_pos = is_pos.sum().item()
        if num_pos == 0:
            return 0.0
        # precision@i = (#positives in the top i) / i, accumulated at positive ranks only.
        ranks = torch.arange(1, ranked.numel() + 1, dtype=torch.float64, device=ranked.device)
        precision_at_i = is_pos.cumsum(0) / ranks
        return float((precision_at_i * is_pos).sum().item() / num_pos)

    def overall(self):
        """Overall and per-class precision/recall/F1 with scores thresholded at 0.

        Returns ``(OP, OR, OF1, CP, CR, CF1)`` (see ``evaluation``), or 0 if
        nothing has been added.
        """
        if self.scores.numel() == 0:
            return 0
        scores = self.scores.cpu().numpy()
        targets = self.targets.cpu().numpy()
        # np.where returns a new array. For CPU tensors .numpy() shares memory
        # with self.targets, so the former in-place `targets[targets == -1] = 0`
        # silently rewrote the stored labels (negatives became "difficult" 0s).
        targets = np.where(targets == -1, 0, targets)
        return self.evaluation(scores, targets)

    def overall_topk(self, k):
        """Like ``overall`` but predicts only each example's top-``k`` classes.

        A top-k class counts as predicted positive only if its score is >= 0;
        all other classes are predicted negative. Returns 0 if nothing has been
        added.
        """
        if self.scores.numel() == 0:
            return 0
        raw = self.scores.cpu().numpy()
        topk_idx = self.scores.topk(k, 1, True, True)[1].cpu().numpy()
        rows = np.arange(raw.shape[0])[:, np.newaxis]
        decisions = np.full(raw.shape, -1.0)
        decisions[rows, topk_idx] = np.where(raw[rows, topk_idx] >= 0, 1.0, -1.0)
        targets = self.targets.cpu().numpy()
        # Copy instead of editing in place (see overall()).
        targets = np.where(targets == -1, 0, targets)
        return self.evaluation(decisions, targets)

    def evaluation(self, scores_, targets_):
        """Compute precision, recall and F1 pooled (O*) and per-class averaged (C*).

        Args:
            scores_ (ndarray): NxK scores; a score >= 0 counts as a predicted
                positive (so scores are presumably logits or +/-1 decisions).
            targets_ (ndarray): NxK labels; 1 is positive and -1 is treated as 0.
                The array is not modified.

        Returns:
            tuple: ``(OP, OR, OF1, CP, CR, CF1)``.
        """
        _, n_class = scores_.shape
        targets = np.where(targets_ == -1, 0, targets_)
        predicted = scores_ >= 0
        Ng = np.sum(targets == 1, axis=0).astype(np.float64)         # ground-truth positives per class
        Np = np.sum(predicted, axis=0).astype(np.float64)            # predicted positives per class
        Nc = np.sum(targets * predicted, axis=0).astype(np.float64)  # true positives per class
        Np[Np == 0] = 1  # classes with no predictions get precision 0 instead of 0/0
        OP = np.sum(Nc) / np.sum(Np)
        OR = np.sum(Nc) / np.sum(Ng)
        OF1 = (2 * OP * OR) / (OP + OR)

        CP = np.sum(Nc / Np) / n_class
        CR = np.sum(Nc / Ng) / n_class
        CF1 = (2 * CP * CR) / (CP + CR)
        return OP, OR, OF1, CP, CR, CF1
-
def gen_A(num_classes, t, adj_file, p=0.25):
    """Build the re-weighted, thresholded label-correlation matrix.

    Loads a pickled dict holding ``'adj'`` (presumably pairwise label
    co-occurrence counts) and ``'nums'`` (presumably per-label counts). It then:

    1. divides row i by ``nums[i]``;
    2. binarizes at threshold ``t``;
    3. rescales every column so that its entries sum to about ``p``;
    4. adds the identity.

    Parameters
    ----------
    num_classes : int
        Number of labels; size of the identity that is added.
    t : float
        Binarization threshold applied to ``adj[i, j] / nums[i]``.
    adj_file : str
        Path to the pickle file. NOTE: ``pickle.load`` can execute arbitrary
        code; only load files from a trusted source.
    p : float, optional
        Total weight spread over each column before the identity is added
        (default 0.25, the previously hard-coded value).

    Returns
    -------
    numpy.ndarray
        ``num_classes x num_classes`` float matrix.
    """
    import pickle
    with open(adj_file, 'rb') as f:
        result = pickle.load(f)
    _adj = result['adj']
    _nums = result['nums'][:, np.newaxis]
    # A label with zero count would give 0/0 = NaN, which survives both
    # threshold comparisons and poisons every column sum; divide by 1 instead.
    _nums = np.where(_nums == 0, 1, _nums)
    _adj = _adj / _nums
    _adj[_adj < t] = 0
    _adj[_adj >= t] = 1
    _adj = _adj * p / (_adj.sum(0, keepdims=True) + 1e-6)
    # np.int was removed in NumPy 1.24; it was an alias of the builtin int.
    _adj = _adj + np.identity(num_classes, dtype=int)
    return _adj
-
def gen_adj(A):
    """Normalize ``A`` by the inverse square root of its row sums.

    With ``d = A.sum(1) ** -0.5`` and ``D = diag(d)``, returns ``(A @ D).T @ D``,
    i.e. ``D @ A.T @ D``, so entry ``[i, j]`` equals ``A[j, i] * d[i] * d[j]``.
    Assumes ``A`` is a square 2-D float tensor with positive row sums
    (TODO: confirm at call sites; a zero row sum yields ``inf``).
    """
    inv_sqrt_deg = A.sum(1).float().pow(-0.5)
    d_mat = torch.diag(inv_sqrt_deg)
    return (A @ d_mat).t() @ d_mat
|