|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """
- Data utils used in Bert finetune and evaluation.
- """
-
- import numpy as np
class Tuple():
    """
    Apply one batchify function to each field of a batch of samples.

    Construct as ``Tuple(A, B, C)`` or with a single sequence
    ``Tuple([A, B, C])`` / ``Tuple((A, B, C))``. Calling the instance with a
    list of samples, where each sample holds one field per function, passes
    field ``i`` of every sample to the ``i``-th function. The results are
    concatenated into one flat tuple. A function that returns a tuple or list
    contributes all of its elements, for example padded data plus lengths.
    """

    def __init__(self, fn, *args):
        """
        Args:
            fn: A callable, or a list/tuple of callables.
            *args: Further callables. Only allowed when ``fn`` is a single
                callable; mixing the sequence form with extra positional
                arguments is ambiguous and rejected.

        Raises:
            AssertionError: If the input pattern is ambiguous or any of the
                supplied functions is not callable.
        """
        if isinstance(fn, (list, tuple)):
            # Sequence form: all functions come from ``fn``; extra positional
            # arguments are not allowed.
            assert not args, (
                'Input pattern not understood. The input of Tuple can be '
                'Tuple(A, B, C) or Tuple([A, B, C]) or Tuple((A, B, C)). '
                'Received fn=%s, args=%s' % (str(fn), str(args)))
            self._fn = fn
        else:
            self._fn = (fn,) + args
        for i, ele_fn in enumerate(self._fn):
            assert callable(ele_fn), (
                'Batchify functions must be callable! type(fn[%d]) = %s' % (
                    i, str(type(ele_fn))))

    def __call__(self, data):
        """
        Batchify ``data`` field by field.

        Args:
            data: A sequence of samples; each sample is indexable and holds
                exactly one field per batchify function.

        Returns:
            tuple: The concatenated outputs of all batchify functions.

        Raises:
            AssertionError: If any sample has the wrong number of fields.
        """
        num_fields = len(self._fn)
        # Validate every sample, not just the first. Otherwise a malformed
        # sample later in the batch would surface as an IndexError below or
        # be silently truncated.
        for sample in data:
            assert len(sample) == num_fields, (
                'The number of attributes in each data sample should contain'
                ' {} elements'.format(num_fields))
        ret = []
        for i, ele_fn in enumerate(self._fn):
            result = ele_fn([ele[i] for ele in data])
            if isinstance(result, (tuple, list)):
                ret.extend(result)
            else:
                ret.append(result)
        return tuple(ret)
-
class Pad():
    """
    Pad a batch of variable-length samples to a common size and stack them.

    Each sample is converted with ``np.asarray``. Samples are padded along
    ``axis`` to the largest size in the batch and stacked along a new leading
    batch dimension.

    Args:
        pad_val: Scalar written into padded positions.
        axis: Axis of each sample along which padding is applied.
        ret_length: If falsy, only the padded batch is returned. If ``True``,
            the original per-sample lengths along ``axis`` are also returned
            as an ``int32`` array. Any other truthy value is used as the dtype
            of that lengths array.
        dtype: dtype of the padded batch. ``None`` uses the dtype of the first
            sample.
        pad_right: If True, data is placed at the start and padding appended
            after it. If False, padding is placed before the data.
    """

    def __init__(self,
                 pad_val=0,
                 axis=0,
                 ret_length=None,
                 dtype=None,
                 pad_right=True):
        self._pad_val = pad_val
        self._axis = axis
        self._ret_length = ret_length
        self._dtype = dtype
        self._pad_right = pad_right

    def __call__(self, data):
        """
        Pad and stack ``data``.

        Args:
            data: A non-empty sequence of array-likes that agree in every
                dimension except ``axis``.

        Returns:
            The padded batch of shape ``(len(data),) + padded_sample_shape``.
            When ``ret_length`` is truthy, returns ``(batch, lengths)``
            instead.

        Raises:
            ValueError: If ``data`` is empty.
        """
        arrs = [np.asarray(ele) for ele in data]
        if not arrs:
            raise ValueError('Pad received an empty batch; nothing to pad.')
        original_length = [ele.shape[self._axis] for ele in arrs]
        max_size = max(original_length)
        ret_shape = list(arrs[0].shape)
        ret_shape[self._axis] = max_size
        ret_shape = (len(arrs),) + tuple(ret_shape)
        ret = np.full(
            shape=ret_shape,
            fill_value=self._pad_val,
            dtype=arrs[0].dtype if self._dtype is None else self._dtype)
        for i, arr in enumerate(arrs):
            length = arr.shape[self._axis]
            if length == max_size:
                ret[i] = arr
                continue
            if length == 0:
                # Nothing to copy; this row stays entirely padding.
                continue
            slices = [slice(None)] * arr.ndim
            if self._pad_right:
                slices[self._axis] = slice(0, length)
            else:
                slices[self._axis] = slice(max_size - length, max_size)
            # Leading integer index selects the sample inside the batch.
            ret[(i,) + tuple(slices)] = arr
        if not self._ret_length:
            return ret
        # ``True`` requests the default int32 lengths. Any other truthy value
        # is the dtype for the lengths array. The previous implementation
        # always returned int32 and silently ignored a requested dtype.
        length_dtype = "int32" if self._ret_length is True else self._ret_length
        return ret, np.asarray(original_length, dtype=length_dtype)
-
class Stack():
    """
    Stack a batch of equally-shaped samples along a new axis.

    Args:
        axis: Position of the new axis in the result, forwarded to
            ``np.stack``.
        dtype: Optional dtype to cast the stacked array to. ``None`` keeps
            the dtype chosen by ``np.stack``.
    """

    def __init__(self, axis=0, dtype=None):
        self._axis = axis
        self._dtype = dtype

    def __call__(self, data):
        """
        Stack ``data`` with ``np.stack`` and optionally cast the result.

        Args:
            data: A sequence of array-likes with identical shapes.

        Returns:
            numpy.ndarray: The stacked (and possibly cast) array.
        """
        stacked = np.stack(data, axis=self._axis)
        # Compare against None explicitly, as ``Pad`` does. ``np.dtype``
        # instances define ``__len__`` (number of fields, 0 for plain
        # dtypes), so a truthiness test treats e.g. ``np.dtype('int64')`` as
        # False and silently skips the cast.
        if self._dtype is not None:
            stacked = stacked.astype(self._dtype)
        return stacked
|