|
- import torchvision.transforms as transforms
- from torchvision.datasets import CIFAR10
-
- from torch.utils.data import DataLoader
-
class MYCIFAR10(CIFAR10):
    """CIFAR10 variant whose integrity check always succeeds."""

    def _check_integrity(self) -> bool:
        """Report the on-disk dataset as valid without inspecting it."""
        # NOTE(review): presumably this lets pre-staged or re-packed copies of
        # the dataset load without matching the upstream checksums -- confirm
        # against the torchvision version in use.
        return True
-
-
def load_data(data_root):
    """Build the CIFAR-10 training and test datasets.

    Both splits go through ``Resize(224)``, ``ToTensor`` and a per-channel
    ``Normalize``; the training split additionally applies
    ``RandomHorizontalFlip`` before tensor conversion.

    Args:
        data_root: Directory handed to ``MYCIFAR10`` as the dataset root.

    Returns:
        A ``(trainset, testset, num_examples)`` tuple, where ``num_examples``
        is a dict with the keys ``"trainset"`` and ``"testset"`` holding the
        length of the corresponding dataset.
    """
    # Per-channel statistics shared by the training and test pipelines.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    train_steps = [
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    # Evaluation uses the same pipeline minus the random flip.
    test_steps = [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]

    # NOTE(review): MYCIFAR10 reports integrity as always OK, so
    # ``download=True`` presumably never triggers a real download -- confirm.
    trainset = MYCIFAR10(
        data_root,
        train=True,
        download=True,
        transform=transforms.Compose(train_steps),
    )
    testset = MYCIFAR10(
        data_root,
        train=False,
        download=True,
        transform=transforms.Compose(test_steps),
    )

    num_examples = {"trainset": len(trainset), "testset": len(testset)}
    return trainset, testset, num_examples
-
def load_dataloader(data_root, batch_size, num_workers=8, pin_memory=True):
    """Create the CIFAR-10 training and test ``DataLoader`` objects.

    Args:
        data_root: Directory forwarded to ``load_data``.
        batch_size: Batch size used by both loaders.
        num_workers: Worker count passed to both loaders. Defaults to 8, the
            value that was previously hard-coded; lower it (or use 0) on hosts
            with few cores.
        pin_memory: ``pin_memory`` flag passed to both loaders. Defaults to
            True, the value that was previously hard-coded.

    Returns:
        A ``(trainloader, testloader, num_examples)`` tuple, where
        ``num_examples`` is the dict produced by ``load_data``.
    """
    trainset, testset, num_examples = load_data(data_root)

    # Settings common to both loaders; only the shuffle flag differs.
    shared_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    trainloader = DataLoader(trainset, shuffle=True, **shared_kwargs)
    testloader = DataLoader(testset, shuffle=False, **shared_kwargs)
    return trainloader, testloader, num_examples
|