|
- # -*- coding: UTF-8 -*-
- """
- -----------------------------------
- @Author : Encore
- @Date : 2022/12/19
- -----------------------------------
- """
- import pickle
- from tqdm import tqdm
-
- import numpy as np
- import torch
- from transformers import ViTImageProcessor
- from torch.utils.data.dataset import Dataset
- from torch.utils.data.dataloader import DataLoader
-
- from config import args, logger
-
-
- # file = "D:/data/cifar-100-python/test"
- # with open(file, 'rb') as fo:
- # dic = pickle.load(fo, encoding='bytes')
- # print(dic.keys())
- # print(len(dic[b'filenames']))
-
-
class Cifar100Dataset(Dataset):
    """Map-style dataset over CIFAR-100 "python version" pickle files.

    Images are preprocessed eagerly by ``read_file`` with a
    ``ViTImageProcessor`` and kept in memory:

    * ``pixel_values[i]`` is the processor's ``pixel_values`` tensor for a
      single image (a batch of one), so ``collate`` can ``torch.cat``
      samples along dim 0;
    * ``labels[i]`` is a 1-element ``int64`` tensor holding the fine label.

    ``save`` / ``load`` cache these preprocessed lists to disk with pickle.
    """

    def __init__(self, pretrain_dir):
        """Create an empty dataset.

        Args:
            pretrain_dir: Name or path passed to
                ``ViTImageProcessor.from_pretrained``; the resulting processor
                is used by ``read_file`` to preprocess images.
        """
        super().__init__()
        self.pixel_values = []
        self.labels = []

        self.feature_extractor = ViTImageProcessor.from_pretrained(pretrain_dir)
        # CIFAR-100 has 100 fine-grained classes (the labels read below).
        self.label_nums = 100

    def read_file(self, data_path):
        """Read one CIFAR-100 pickle split and append its samples.

        Samples are appended to whatever is already loaded, so several files
        can be read into the same dataset.

        Args:
            data_path: Path to a CIFAR-100 "python version" file (e.g. the
                ``train`` or ``test`` file).

        Raises:
            ValueError: If the file contains a different number of images
                and fine labels.
        """
        # NOTE: pickle.load can execute arbitrary code; only read trusted files.
        with open(data_path, 'rb') as fo:
            dic = pickle.load(fo, encoding='bytes')
        fine_labels = dic[b"fine_labels"]
        # Each row holds the three 32x32 colour planes back to back; reshape to
        # (N, 3, 32, 32) and move channels last -> (N, 32, 32, 3) for the
        # image processor. Done once for the whole array instead of per row.
        images = np.reshape(dic[b"data"], (-1, 3, 32, 32)).transpose(0, 2, 3, 1)
        if len(images) != len(fine_labels):
            raise ValueError(
                f"{data_path}: {len(images)} images but {len(fine_labels)} labels"
            )

        # Pass `total` explicitly: zip/enumerate have no len(), so tqdm could
        # otherwise not show progress percentage or ETA.
        for image, label in tqdm(zip(images, fine_labels), total=len(fine_labels)):
            feature = self.feature_extractor(image, return_tensors="pt")
            self.pixel_values.append(feature['pixel_values'])
            self.labels.append(torch.tensor([label], dtype=torch.int64))

    def save(self, filename):
        """Pickle the preprocessed samples to ``filename``."""
        dic = {
            "pixel_values": self.pixel_values,
            "labels": self.labels,
        }
        with open(filename, 'wb') as f:
            pickle.dump(dic, f)

    def load(self, filename):
        """Replace the in-memory samples with those pickled by ``save``.

        NOTE: pickle.load can execute arbitrary code; only load trusted files.
        """
        with open(filename, 'rb') as f:
            dic = pickle.load(f)
        self.pixel_values = dic["pixel_values"]
        self.labels = dic["labels"]

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with ``pixel_values`` and ``labels``."""
        return {
            "pixel_values": self.pixel_values[idx],
            "labels": self.labels[idx],
        }

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.pixel_values)

    @staticmethod
    def collate(batch_data):
        """Merge samples from ``__getitem__`` into one batch for a DataLoader.

        Both fields are concatenated along dim 0: per-sample pixel tensors
        (each a batch of one) become one batch tensor, and the 1-element label
        tensors become a 1-D label tensor.
        """
        batch_pixel_values = torch.cat([data["pixel_values"] for data in batch_data], dim=0)
        batch_labels = torch.cat([data["labels"] for data in batch_data], dim=0)

        return {
            "batch_pixel_values": batch_pixel_values,
            "batch_labels": batch_labels,
        }
-
-
if __name__ == '__main__':
    # Smoke test: preprocess the split at args.test_path, round-trip it through
    # the on-disk cache, then report the dataset size and the number of batches.
    cache_file = "test.dat"

    test_set = Cifar100Dataset(args.pretrain_dir)
    test_set.read_file(args.test_path)
    test_set.save(cache_file)
    test_set.load(cache_file)

    test_loader = DataLoader(
        test_set,
        batch_size=3,
        collate_fn=Cifar100Dataset.collate,
    )

    print(len(test_set))
    print(len(test_loader))
|