|
-
- import pickle
- import numpy as np
- import os
- import random
- import shutil
-
-
def get_data(input_data_folder, input_data_fileList):
    """Load CIFAR-10 python-format batch files.

    Args:
        input_data_folder: root folder containing 'cifar-10-batches-py'.
        input_data_fileList: batch file names to load, in order.

    Returns:
        (data, targets): *data* is a np.ndarray vertically stacking every
        batch's 'data' rows; *targets* is a flat list of their 'labels'.
    """
    batch_dir = os.path.join(input_data_folder, 'cifar-10-batches-py')
    batch_arrays = []
    targets = []
    for batch_name in input_data_fileList:
        # NOTE: pickle.load is acceptable here only because these are local,
        # self-produced CIFAR batches — never unpickle untrusted files.
        with open(os.path.join(batch_dir, batch_name), 'rb') as batch_file:
            entry = pickle.load(batch_file, encoding='latin1')
        batch_arrays.append(entry['data'])
        targets.extend(entry['labels'])
    return np.vstack(batch_arrays), targets
-
-
-
def random_split_list(input_list, ratios, random_flag=True):
    """Split *input_list* into len(ratios) pieces plus a remainder.

    Args:
        input_list: sequence to split.
        ratios: fraction of len(input_list) to place in each output piece.
        random_flag: when True, shuffle the element order before splitting.

    Returns:
        (outputs_list, remainder_list): outputs_list[i] holds the elements
        assigned to ratios[i]; remainder_list holds everything left over.
    """
    input_len = len(input_list)
    input_indices = list(range(input_len))
    if random_flag:
        random.shuffle(input_indices)
    split_indices_list = []
    split_start_index = 0
    for ratio in ratios:
        split_end_index = split_start_index + int(input_len * ratio)
        split_indices_list.append(input_indices[split_start_index:split_end_index])
        split_start_index = split_end_index
    outputs_list = [[input_list[i] for i in indices] for indices in split_indices_list]
    # BUG FIX: the remainder must be taken from the (possibly shuffled)
    # index permutation, not from the original positional slice — otherwise,
    # with random_flag=True, elements could be duplicated or dropped.
    remainder_list = [input_list[i] for i in input_indices[split_start_index:]]
    return outputs_list, remainder_list
-
def split_indices_by_label(input_list, classes=10):
    """Group the positions of *input_list* by their label value.

    Returns a list of length *classes*; entry c holds the indices whose
    label equals c, in ascending order.
    """
    grouped = [[] for _ in range(classes)]
    for position, label in enumerate(input_list):
        grouped[label].append(position)
    return grouped
-
def split_indices_list(target, ratios_list):
    """Compute per-split sample indices using per-class ratios.

    Args:
        target: label for every sample.
        ratios_list: ratios_list[c] gives, for class c, the fraction of
            that class assigned to each split; every row has equal length.

    Returns:
        (output_indices_list, remainder_indices): sorted sample indices for
        each split, plus the sorted indices not assigned to any split.
    """
    num_classes = len(ratios_list)
    num_splits = len(ratios_list[0])
    per_split_indices = [[] for _ in range(num_splits)]
    leftover_indices = []
    for label, class_indices in enumerate(split_indices_by_label(target, num_classes)):
        # random_flag=False keeps the within-class assignment deterministic.
        class_splits, class_leftover = random_split_list(
            class_indices, ratios_list[label], random_flag=False)
        for bucket, chunk in zip(per_split_indices, class_splits):
            bucket.extend(chunk)
        leftover_indices.extend(class_leftover)
    for bucket in per_split_indices:
        bucket.sort()
    leftover_indices.sort()
    return per_split_indices, leftover_indices
-
def split_data_and_target(data, target, ratios_list):
    """Partition (data, target) into per-split chunks with per-class ratios.

    Args:
        data: np.ndarray of samples, one per row.
        target: label list aligned with *data*.
        ratios_list: per-class split ratios (see split_indices_list).

    Returns:
        (split_data_list, split_target_list): one vertically stacked
        np.ndarray and one label list per split; samples left over by the
        ratios are discarded here.
    """
    per_split_indices, _ = split_indices_list(target, ratios_list)
    split_data_list = []
    split_target_list = []
    for indices in per_split_indices:
        split_data_list.append(np.vstack([data[i] for i in indices]))
        split_target_list.append([target[i] for i in indices])
    return split_data_list, split_target_list
-
-
def write_current_data_and_target(file_path, data, target, mode='python'):
    """Serialize one (data, target) chunk to *file_path*.

    mode='python': pickle a {'data', 'labels'} dict (CIFAR python format).
    mode='bin':    CIFAR binary layout — for each sample, one label byte
                   (big-endian) followed by the raw image bytes.
    """
    if mode == 'python':
        with open(file_path, 'wb') as out_file:
            pickle.dump({'data': data, 'labels': target}, out_file)
        return
    assert mode == 'bin'
    with open(file_path, 'wb') as out_file:
        for sample_index, label in enumerate(target):
            record = int(label).to_bytes(1, 'big') + data[sample_index].tobytes()
            out_file.write(record)
-
def write_data_and_target(output_folder, output_data_fileList, data, target, mode='python'):
    """Split (data, target) across the given batch files and write them out.

    Args:
        output_folder: root output directory.
        output_data_fileList: batch file names; samples are chunked evenly
            across them in order.
        data: np.ndarray of samples (rows).
        target: label list aligned with *data*.
        mode: 'python' -> pickled batches under 'cifar-10-batches-py';
              'bin'    -> raw CIFAR binary batches under 'mindspore_train'
                          (file names get a '.bin' suffix).
    """
    file_len = len(output_data_fileList)
    # FIX: use integer ceiling division instead of float `n / k (+ 1)`:
    # it covers every sample exactly once and cannot be skewed by float
    # truncation of the slice boundaries.
    data_piece_size = -(-data.shape[0] // file_len)
    if mode == 'python':
        folder_path = os.path.join(output_folder, 'cifar-10-batches-py')
    else:
        assert mode == 'bin'
        folder_path = os.path.join(output_folder, 'mindspore_train')
    # Hoisted out of the loop: the target folder is the same for all chunks.
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    for index in range(file_len):
        start = index * data_piece_size
        stop = start + data_piece_size
        file_name = output_data_fileList[index]
        if mode == 'bin':
            file_name += '.bin'
        write_current_data_and_target(os.path.join(folder_path, file_name),
                                      data[start:stop], target[start:stop], mode)
-
def copy_meta_files(input_folder, output_folder):
    """Copy the CIFAR-10 metadata and test batch into the output tree.

    FIX: create the destination 'cifar-10-batches-py' folder when missing —
    nothing else creates it for 'bin'-mode outputs, so copyfile would fail
    with FileNotFoundError there.
    """
    base_folder = 'cifar-10-batches-py'
    destination_dir = os.path.join(output_folder, base_folder)
    os.makedirs(destination_dir, exist_ok=True)
    for file_name in ('batches.meta', 'test_batch'):
        shutil.copyfile(os.path.join(input_folder, base_folder, file_name),
                        os.path.join(destination_dir, file_name))
-
-
-
-
def data_split_process():
    """Split the CIFAR-10 training set into 3 non-IID clients plus an 'all' copy.

    Reads the five training batches, assigns each class to clients with the
    configured ratios, and writes python-format batches (plus metadata)
    under ./data/cifar10_split_to_3/<scheme>/client_<k> and .../all.
    """
    input_data_folder = './data/cifar-10-py'
    data_fileList = ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5']
    output_data_folder = './data/cifar10_split_to_3'
    data, targets = get_data(input_data_folder, data_fileList)
    # One ratio row per class: each client dominates a third of the classes;
    # the last class is spread evenly (10% left unassigned by design).
    ratios_dict = {'non_iid_train': [[0.6, 0.2, 0.2]] * 3 + [[0.2, 0.6, 0.2]] * 3 + [[0.2, 0.2, 0.6]] * 3 + [[0.3, 0.3, 0.3]]}
    mode_list = ['python']
    for ratios_key, ratios_value in ratios_dict.items():
        split_data_list, split_target_list = split_data_and_target(data, targets, ratios_value)
        for client_rank, (split_data, split_target) in enumerate(zip(split_data_list, split_target_list)):
            for mode in mode_list:
                current_output_folder = os.path.join(output_data_folder, ratios_key, 'client_{}'.format(client_rank))
                write_data_and_target(current_output_folder, data_fileList, split_data, split_target, mode)
                copy_meta_files(input_data_folder, current_output_folder)
        for mode in mode_list:
            current_output_folder = os.path.join(output_data_folder, ratios_key, 'all')
            write_data_and_target(current_output_folder, data_fileList, data, targets, mode)
            copy_meta_files(input_data_folder, current_output_folder)
-
if __name__ == '__main__':
    # Run the split only when executed as a script, not when imported.
    data_split_process()
|