You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

80 lines
2.7 KiB

  1. # -*- coding:utf-8 -*-
  2. import numpy as np
  3. from joblib import Parallel, delayed
  4. from .log_utils import get_logger
# Module-level logger shared by get_windows and calc_accuracy below;
# get_logger is provided by the project's .log_utils module.
LOGGER = get_logger('win.win_helper')
  6. def get_windows_channel(X, X_win, des_id, nw, nh, win_x, win_y, stride_x, stride_y):
  7. """
  8. X: N x C x H x W
  9. X_win: N x nc x nh x nw
  10. (k, di, dj) in range(X.channle, win_y, win_x)
  11. """
  12. #des_id = (k * win_y + di) * win_x + dj
  13. dj = des_id % win_x
  14. di = des_id / win_x % win_y
  15. k = des_id / win_x / win_y
  16. src = X[:, k, dj:dj+nh*stride_x:stride_x, di:di+nw*stride_y:stride_y].ravel()
  17. des = X_win[des_id, :]
  18. np.copyto(des, src)
  19. def get_windows(X, win_x, win_y, stride_x=1, stride_y=1, pad_x=0, pad_y=0):
  20. """
  21. parallizing get_windows
  22. Arguments:
  23. X (ndarray): n x c x h x w
  24. Return:
  25. X_win (ndarray): n x nh x nw x nc
  26. """
  27. assert len(X.shape) == 4
  28. n, c, h, w = X.shape
  29. if pad_y > 0:
  30. X = np.concatenate(( X, np.zeros((n, c, pad_y, w),dtype=X.dtype) ), axis=2)
  31. X = np.concatenate(( np.zeros((n, c, pad_y, w),dtype=X.dtype), X ), axis=2)
  32. n, c, h, w = X.shape
  33. if pad_x > 0:
  34. X = np.concatenate(( X, np.zeros((n, c, h, pad_x),dtype=X.dtype) ), axis=3)
  35. X = np.concatenate(( np.zeros((n, c, h, pad_x),dtype=X.dtype), X ), axis=3)
  36. n, c, h, w = X.shape
  37. nc = win_y * win_x * c
  38. nh = (h - win_x) / stride_x + 1
  39. nw = (w - win_y) / stride_y + 1
  40. X_win = np.empty(( nc, n * nh * nw), dtype=np.float32)
  41. LOGGER.info("get_windows_start: X.shape={}, X_win.shape={}, nw={}, nh={}, c={}, win_x={}, win_y={}, stride_x={}, stride_y={}".format(
  42. X.shape, X_win.shape, nw, nh, c, win_x, win_y, stride_x, stride_y))
  43. Parallel(n_jobs=-1, backend="threading", verbose=0)(
  44. delayed(get_windows_channel)(X, X_win, des_id, nw, nh, win_x, win_y, stride_x, stride_y)
  45. for des_id in range(c * win_x * win_y))
  46. LOGGER.info("get_windows_end")
  47. X_win = X_win.transpose((1, 0))
  48. X_win = X_win.reshape((n, nh, nw, nc))
  49. return X_win
  50. def calc_accuracy(y_gt, y_pred, tag):
  51. LOGGER.info("Accuracy({})={:.2f}%".format(tag, np.sum(y_gt==y_pred)*100./len(y_gt)))
  52. def win_vote(y_win_predict, n_classes):
  53. """
  54. y_win_predict (ndarray): n x n_window
  55. y_win_predict[i, j] prediction for the ith data of jth window
  56. """
  57. y_pred = np.zeros(len(y_win_predict), dtype=np.int16)
  58. for i, y_bag in enumerate(y_win_predict):
  59. y_pred[i] = np.argmax(np.bincount(y_bag,minlength=n_classes))
  60. return y_pred
  61. def win_avg(y_win_proba):
  62. """
  63. Parameters
  64. ----------
  65. y_win_proba: n x n_windows x n_classes
  66. """
  67. n_classes = y_win_proba.shape[-1]
  68. y_bag_proba = np.mean(y_win_proba, axis=1)
  69. y_pred = np.argmax(y_bag_proba, axis=1)
  70. return y_pred