|
- import os
- import sys
- import numpy as np
- from six.moves import cPickle
-
- import keras.backend as K
-
# Total number of images in the full CIFAR-10 training set
# (5 pickle batches x 10000 rows each; see split_cifar10_2_parts below).
num_train_samples = 60000
-
-
def load_batch(fpath, label_key='labels'):
    """Parse a single pickled CIFAR batch file.

    # Arguments
        fpath: path of the file to parse.
        label_key: dictionary key under which the labels are stored.

    # Returns
        A tuple `(data, labels)` where `data` has shape
        `(num_samples, 3, 32, 32)`.
    """
    with open(fpath, 'rb') as handle:
        if sys.version_info < (3,):
            batch = cPickle.load(handle)
        else:
            # Batches pickled under Python 2 come back with byte-string
            # keys; re-key the dict with their utf8-decoded equivalents.
            raw = cPickle.load(handle, encoding='bytes')
            batch = {key.decode('utf8'): value for key, value in raw.items()}
    images = batch['data']
    # Rows are flat 3072-byte vectors; restore channels-first image layout.
    images = images.reshape(images.shape[0], 3, 32, 32)
    return images, batch[label_key]
-
-
def load_data(train_dir, val_dir):
    """Load the training and test splits of CIFAR10 from two directories.

    # Arguments
        train_dir: directory containing the `train_batch` file.
        val_dir: directory containing the `test_batch` file.

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    x_train, y_train = load_batch(os.path.join(train_dir, 'train_batch'))
    x_test, y_test = load_batch(os.path.join(val_dir, 'test_batch'))

    # Labels arrive as flat lists; give them an explicit column axis.
    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))

    # Batches are stored channels-first; reorder for channels-last backends.
    if K.image_data_format() == 'channels_last':
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)

    return (x_train, y_train), (x_test, y_test)
-
-
def load_train_data(train_dir):
    """Load only the CIFAR10 training split from `train_dir`.

    # Returns
        Tuple `(x_train, y_train)` with labels shaped `(n, 1)`.
    """
    batch_path = os.path.join(train_dir, 'train_batch')
    x_train, y_train = load_batch(batch_path)
    y_train = np.reshape(y_train, (len(y_train), 1))
    # Reorder channels-first data for channels-last backends.
    if K.image_data_format() == 'channels_last':
        x_train = x_train.transpose(0, 2, 3, 1)
    return (x_train, y_train)
-
-
def load_test_data(val_dir):
    """Load only the CIFAR10 test split from `val_dir`.

    # Returns
        Tuple `(x_test, y_test)` with labels shaped `(n, 1)`.
    """
    batch_path = os.path.join(val_dir, 'test_batch')
    x_test, y_test = load_batch(batch_path)
    y_test = np.reshape(y_test, (len(y_test), 1))
    # Reorder channels-first data for channels-last backends.
    if K.image_data_format() == 'channels_last':
        x_test = x_test.transpose(0, 2, 3, 1)
    return (x_test, y_test)
-
-
def split_cifar10_2_parts(data_root, split_index=25000):
    """Split the five CIFAR-10 training batches into two pickle files.

    Reads `data_batch_1` .. `data_batch_5` from `data_root`, concatenates
    their rows, and writes `data_batch_a` (rows `[:split_index]`) and
    `data_batch_b` (rows `[split_index:]`) back into `data_root`.

    # Arguments
        data_root: directory holding the original CIFAR-10 batch files.
        split_index: row index at which the concatenated data is split
            into the two output files. Defaults to 25000, i.e. an even
            25000/25000 split of the 60000 training rows.
    """
    def _load_batch(fpath):
        # Load one pickled batch and utf8-decode its byte-string keys
        # (Python 3 only; Python 2 pickles already carry str keys).
        with open(fpath, 'rb') as f:
            if sys.version_info < (3,):
                d = cPickle.load(f)
            else:
                d = cPickle.load(f, encoding='bytes')
                d = {k.decode('utf8'): v for k, v in d.items()}
        return d

    # CIFAR-10 images are 32x32 RGB, i.e. 3*32*32 bytes per flattened row.
    data = np.empty((num_train_samples, 3 * 32 * 32), dtype=np.uint8)
    labels = []
    filenames = []
    for i in range(1, 6):
        fpath = os.path.join(data_root, 'data_batch_' + str(i))
        d = _load_batch(fpath)
        # Each original batch contributes 10000 consecutive rows.
        data[(i - 1) * 10000: i * 10000, ...] = d['data']
        labels.extend(d['labels'])
        filenames.extend(d['filenames'])

    parts = (
        ('data_batch_a', slice(None, split_index)),
        ('data_batch_b', slice(split_index, None)),
    )
    for fname, part in parts:
        with open(os.path.join(data_root, fname), 'wb') as f:
            # Byte-string keys keep the output loadable by the
            # `encoding='bytes'` reader in `load_batch` above.
            cPickle.dump(
                {b'data': data[part],
                 b'labels': labels[part],
                 b'filenames': filenames[part]},
                f)
-
-
if __name__ == "__main__":
    # Module is intended to be imported; no standalone behaviour.
    pass
|