|
- import os
-
- import torch
- from torchvision import datasets, transforms
-
-
class ColorAugmentation(object):
    """PCA-based colour ("lighting") augmentation.

    Draws random coefficients ``alpha_i ~ N(0, alphastd**2)`` and adds the
    same per-channel offset to every pixel of a 3-channel tensor:

        offset = sum_i (alpha_i * eig_val[i]) * eig_vec[i, :]

    Args:
        eig_vec: 3x3 matrix (tensor or nested list) whose *rows* are the
            eigenvectors; row ``i`` is paired with ``eig_val[i]``. Defaults to
            the ImageNet RGB eigenvectors.
        eig_val: the 3 eigenvalues, shape ``(1, 3)`` or ``(3,)``. Defaults to
            the ImageNet RGB eigenvalues.
        alphastd: standard deviation of the random coefficients ``alpha``
            (0.1, the previously hard-coded value, by default).
    """

    def __init__(self, eig_vec=None, eig_val=None, alphastd=0.1):
        # ``is None`` instead of ``== None``: ``==`` on a tensor is an
        # element-wise comparison, not an identity test.
        if eig_vec is None:
            # One eigenvector per ROW, ordered to match the default
            # eigenvalues below (largest first). The first row is the
            # near-uniform "brightness" direction. The previous default held
            # the eigenvectors as columns in reverse order, which paired
            # 0.2175 with a non-eigenvector direction.
            eig_vec = [
                [-0.5675, -0.5808, -0.5836],
                [0.7192, -0.0045, -0.6948],
                [0.4009, -0.8140, 0.4203],
            ]
        if eig_val is None:
            eig_val = [[0.2175, 0.0188, 0.0045]]
        self.eig_val = torch.as_tensor(eig_val, dtype=torch.float32).reshape(1, 3)  # 1*3
        self.eig_vec = torch.as_tensor(eig_vec, dtype=torch.float32)  # 3*3
        self.alphastd = alphastd

    def __call__(self, tensor):
        """Return ``tensor`` plus a random per-channel PCA offset.

        Args:
            tensor: image tensor with 3 channels in dimension 0 (e.g. the
                ``(C, H, W)`` output of ``transforms.ToTensor()``).

        Returns:
            A new tensor; the input is not modified in place.
        """
        assert tensor.size(0) == 3
        alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * self.alphastd
        # (1, 3) @ (3, 3): weighted sum of the eigenvector rows.
        quantity = torch.mm(self.eig_val * alpha, self.eig_vec)
        # Follow the image's device so CUDA tensors can be augmented too;
        # dtype is left to normal type promotion, as before.
        return tensor + quantity.view(3, 1, 1).to(device=tensor.device)
-
-
class ImageNet:
    """Train/val ``DataLoader`` pair built from an ImageFolder directory tree.

    ``args.data`` must contain ``train`` and ``val`` sub-directories, each
    readable by ``datasets.ImageFolder``.

    Args:
        args: namespace providing ``data`` (dataset root directory),
            ``workers`` (DataLoader worker count) and ``batch_size``
            (training batch size; validation uses a quarter of it, min 1).

    Attributes:
        train_loader: shuffled loader with random-resized-crop (224),
            horizontal flip and ``ColorAugmentation`` before normalisation.
        val_loader: unshuffled loader with resize (256) + center-crop (224).
    """

    def __init__(self, args):
        super().__init__()

        data_root = args.data
        traindir = os.path.join(data_root, "train")
        valdir = os.path.join(data_root, "val")

        # Worker processes help with or without a GPU, so always honour
        # ``args.workers`` (previously dropped on CPU-only machines).
        # Pinned memory only benefits host-to-GPU copies, so gate it on CUDA.
        kwargs = {"num_workers": args.workers}
        if torch.cuda.is_available():
            kwargs["pin_memory"] = True

        # Per-channel RGB mean/std (the usual ImageNet statistics).
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                # Colour jitter is applied to the ToTensor output, before
                # normalisation.
                ColorAugmentation(),
                normalize,
            ]),
        )
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs
        )

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]),
        )
        # Validation uses a quarter of the training batch size. Clamp to 1 so
        # a training batch size below 4 does not produce batch_size=0, which
        # DataLoader rejects.
        self.val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=max(1, args.batch_size // 4),
            shuffle=False,
            **kwargs
        )
|