|
- #!/usr/bin/env python
- """
- Classifier is an image classifier specialization of Net.
- """
-
- import numpy as np
-
- import caffe
-
-
class Classifier(caffe.Net):
    """
    Classifier extends Net for image class prediction
    by scaling, center cropping, or oversampling.

    Parameters
    ----------
    model_file, pretrained_file : forwarded to caffe.Net, together with
        caffe.TEST as the third argument.
    image_dims : dimensions to scale input for cropping/sampling.
        Default is to scale to net input size for whole-image crop.
        May be any sequence (tuple, list, or ndarray) of (height, width).
    mean, input_scale, raw_scale, channel_swap: params for
        preprocessing options. Each one is applied to the transformer
        only when it is not None.
    """
    def __init__(self, model_file, pretrained_file, image_dims=None,
                 mean=None, input_scale=None, raw_scale=None,
                 channel_swap=None):
        caffe.Net.__init__(self, model_file, pretrained_file, caffe.TEST)

        # configure pre-processing for the first declared net input
        in_ = self.inputs[0]
        self.transformer = caffe.io.Transformer(
            {in_: self.blobs[in_].data.shape})
        # Move the last axis of an (H x W x K) image to the front: (K x H x W).
        self.transformer.set_transpose(in_, (2, 0, 1))
        if mean is not None:
            self.transformer.set_mean(in_, mean)
        if input_scale is not None:
            self.transformer.set_input_scale(in_, input_scale)
        if raw_scale is not None:
            self.transformer.set_raw_scale(in_, raw_scale)
        if channel_swap is not None:
            self.transformer.set_channel_swap(in_, channel_swap)

        # Crop size is taken from the trailing two axes of the input blob.
        self.crop_dims = np.array(self.blobs[in_].data.shape[2:])
        # Test for "not given" explicitly: `not image_dims` raises
        # ValueError when image_dims is a multi-element ndarray.
        if image_dims is None or np.size(image_dims) == 0:
            image_dims = self.crop_dims
        self.image_dims = image_dims

    def predict(self, inputs, oversample=True):
        """
        Predict classification probabilities of inputs.

        Parameters
        ----------
        inputs : iterable of (H x W x K) input ndarrays.
            Must support len() and indexing, and be non-empty (the channel
            count is read from inputs[0]).
        oversample : boolean
            average predictions across center, corners, and mirrors
            when True (default). Center-only prediction when False.

        Returns
        -------
        predictions: (N x C) ndarray of class probabilities for N images and C
            classes.
        """
        # Scale to standardize input dimensions.
        input_ = np.zeros((len(inputs),
                           self.image_dims[0],
                           self.image_dims[1],
                           inputs[0].shape[2]),
                          dtype=np.float32)
        for ix, in_ in enumerate(inputs):
            input_[ix] = caffe.io.resize_image(in_, self.image_dims)

        if oversample:
            # Generate center, corner, and mirrored crops.
            input_ = caffe.io.oversample(input_, self.crop_dims)
        else:
            # Take center crop: build [y0, x0, y1, x1] around the image
            # center, half a crop in each direction.
            center = np.array(self.image_dims) / 2.0
            crop = np.tile(center, (1, 2))[0] + np.concatenate([
                -self.crop_dims / 2.0,
                self.crop_dims / 2.0
            ])
            crop = crop.astype(int)
            input_ = input_[:, crop[0]:crop[2], crop[1]:crop[3], :]

        # Classify: reorder (N, H, W, K) -> (N, K, H, W) for the net input.
        caffe_in = np.zeros(np.array(input_.shape)[[0, 3, 1, 2]],
                            dtype=np.float32)
        for ix, in_ in enumerate(input_):
            caffe_in[ix] = self.transformer.preprocess(self.inputs[0], in_)
        out = self.forward_all(**{self.inputs[0]: caffe_in})
        predictions = out[self.outputs[0]]

        # For oversampling, average predictions across crops.
        if oversample:
            # Assumes 10 crops per image (center + 4 corners, each mirrored),
            # laid out contiguously per image.
            crops_per_image = 10
            # Floor division keeps the dimension an int; true division
            # yields a float on Python 3, which reshape() rejects.
            predictions = predictions.reshape(
                (len(predictions) // crops_per_image, crops_per_image, -1))
            predictions = predictions.mean(1)

        return predictions
|