import os, sys, hashlib
import numpy as np
from PIL import Image
import torch.utils.data as data
import pickle


# def ImageNet16Loader():
#     mean = [x / 255 for x in [122.68, 116.66, 104.01]]
#     std = [x / 255 for x in [63.22, 61.26, 65.09]]
#     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
#     train_transform = transforms.Compose(lists)
#     train_data = ImageNet16(root=os.path.join(args.data, 'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120)
#     assert len(train_data) == 151700


def calculate_md5(fpath, chunk_size=1024 * 1024):
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    md5 = hashlib.md5()
    with open(fpath, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath, md5, **kwargs):
    """Return True if the file's MD5 digest matches the expected value."""
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath, md5=None):
    """Return True if the file exists and, when an MD5 is given, matches it."""
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    else:
        return check_md5(fpath, md5)


class ImageNet16(data.Dataset):
    """16x16 downsampled ImageNet.

    See "A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets"
    (https://arxiv.org/pdf/1707.08819.pdf). The pickled batches are available at
    http://image-net.org/download-images
    """

    train_list = [
        ["train_data_batch_1", "27846dcaa50de8e21a7d1a35f30f0e91"],
        ["train_data_batch_2", "c7254a054e0e795c69120a5727050e3f"],
        ["train_data_batch_3", "4333d3df2e5ffb114b05d2ffc19b1e87"],
        ["train_data_batch_4", "1620cdf193304f4a92677b695d70d10f"],
        ["train_data_batch_5", "348b3c2fdbb3940c4e9e834affd3b18d"],
        ["train_data_batch_6", "6e765307c242a1b3d7d5ef9139b48945"],
        ["train_data_batch_7", "564926d8cbf8fc4818ba23d2faac7564"],
        ["train_data_batch_8", "f4755871f718ccb653440b9dd0ebac66"],
        ["train_data_batch_9", "bb6dd660c38c58552125b1a92f86b5d4"],
        ["train_data_batch_10", "8f03f34ac4b42271a294f91bf480f29b"],
    ]
    valid_list = [
        ["val_data", "3410e3017fdaefba8d5073aaa65e4bd6"],
    ]

    def __init__(self, root, train, transform, use_num_of_class_only=None):
        self.root = root
        self.transform = transform
        self.train = train  # training set or validation set
        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted.")

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.valid_list
        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for i, (file_name, checksum) in enumerate(downloaded_list):
            file_path = os.path.join(self.root, file_name)
            # print('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
            with open(file_path, "rb") as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding="latin1")
                self.data.append(entry["data"])
                self.targets.extend(entry["labels"])
        self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
        if use_num_of_class_only is not None:
            assert (
                isinstance(use_num_of_class_only, int)
                and use_num_of_class_only > 0
                and use_num_of_class_only < 1000
            ), "invalid use_num_of_class_only : {:}".format(use_num_of_class_only)
            # keep only samples whose (1-indexed) label falls within the first N classes
            new_data, new_targets = [], []
            for I, L in zip(self.data, self.targets):
                if 1 <= L <= use_num_of_class_only:
                    new_data.append(I)
                    new_targets.append(L)
            self.data = new_data
            self.targets = new_targets
        # self.mean.append(entry['mean'])
        # self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
        # self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
        # print('Mean : {:}'.format(self.mean))
        # temp = self.data - np.reshape(self.mean, (1, 1, 1, 3))
        # std_data = np.std(temp, axis=0)
        # std_data = np.mean(np.mean(std_data, axis=0), axis=0)
        # print('Std : {:}'.format(std_data))

    def __getitem__(self, index):
        # labels in the pickled files are 1-indexed; shift them to 0-indexed targets
        img, target = self.data[index], self.targets[index] - 1
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in self.train_list + self.valid_list:
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, filename)
            if not check_integrity(fpath, md5):
                return False
        return True
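

# A minimal usage sketch, assuming the pickled batches have been downloaded into
# "./data/ImageNet16" (this path and the __main__ guard are illustrative, not part
# of the dataset itself). It mirrors the commented-out loader at the top of the
# file: ImageNet16-120 keeps only the first 120 classes, giving 151700 training
# images.
if __name__ == "__main__":
    import torchvision.transforms as transforms

    mean = [x / 255 for x in [122.68, 116.66, 104.01]]
    std = [x / 255 for x in [63.22, 61.26, 65.09]]
    train_transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(16, padding=2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    train_data = ImageNet16(
        root="./data/ImageNet16",  # assumed download location
        train=True,
        transform=train_transform,
        use_num_of_class_only=120,
    )
    assert len(train_data) == 151700
    img, target = train_data[0]  # img is a 3x16x16 tensor, target lies in [0, 119]
    print(len(train_data), img.shape, target)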