|
- """COCO dataset loader"""
- import torch
- import torch.utils.data as data
- import os
- import os.path as osp
- import numpy as np
- import random
- import nltk
-
- import logging
-
- logger = logging.getLogger(__name__)
-
-
class PrecompRegionDataset(data.Dataset):
    """Dataset of precomputed image region features paired with captions.

    Reads ``<split>_caps.txt`` (one caption per line, utf-8) and
    ``<split>_ims.npy`` (precomputed region features) from
    ``data_path/data_name``. Handles the rkiros-style 5-captions-per-image
    redundancy via ``im_div``.
    """

    def __init__(self, data_path, data_name, data_split, vocab, opt, train):
        self.vocab = vocab
        self.opt = opt
        self.train = train
        self.data_path = data_path
        self.data_name = data_name

        root = osp.join(data_path, data_name)

        # Captions: one per line, decoded from utf-8 bytes.
        with open(osp.join(root, '%s_caps.txt' % data_split), 'rb') as fh:
            self.captions = [raw.strip().decode('utf-8') for raw in fh]

        # Precomputed region features for every image of this split.
        self.images = np.load(os.path.join(root, '%s_ims.npy' % data_split), allow_pickle=True)

        self.length = len(self.captions)
        # rkiros data stores 5 captions per image; 10-crop data does not.
        self.im_div = 5 if len(self.images) != self.length else 1
        # The coco dev set is large, so cap validation size to keep it fast.
        if data_split == 'dev':
            self.length = 5000

    def __getitem__(self, index):
        # A single image may serve several captions; map accordingly.
        img_index = index // self.im_div
        caption = self.captions[index]

        # Caption string -> word-id tensor (size augmentation when training).
        target = process_caption(self.vocab, caption, self.train)
        image = torch.Tensor(self.images[img_index])
        return image, target, index, img_index

    def __len__(self):
        return self.length
-
-
def process_image(image, train):
    """Augment region features at training time.

    When ``train`` is true, each region is independently kept with
    probability 0.8 and the surviving regions are shuffled in place;
    otherwise the input is returned untouched.
    """
    if train:
        keep = np.random.rand(image.shape[0]) > 0.20
        image = image[np.where(keep)]
        np.random.shuffle(image)
    return image
-
-
def process_caption(vocab, caption, drop=False):
    """Tokenize ``caption`` and map it to a tensor of word ids.

    The sequence is framed by ``<start>`` / ``<end>``. When ``drop`` is
    true, each token is perturbed with probability 0.20: of those, 50% are
    replaced by the ``<mask>`` id, 10% by a random vocabulary id, and the
    remaining 40% are removed from the sequence.
    """
    words = nltk.tokenize.word_tokenize(caption.lower())

    if not drop:
        ids = [vocab('<start>')]
        ids.extend(vocab(w) for w in words)
        ids.append(vocab('<end>'))
        return torch.Tensor(ids)

    tokens = ['<start>'] + words + ['<end>']
    removed = []
    for i, tok in enumerate(tokens):
        prob = random.random()
        if prob < 0.20:
            prob /= 0.20
            if prob < 0.5:
                # 50%: substitute the mask token.
                tokens[i] = vocab.word2idx['<mask>']
            elif prob < 0.6:
                # 10%: substitute a random vocabulary id.
                tokens[i] = random.randrange(len(vocab))
            else:
                # 40%: schedule this position for removal.
                tokens[i] = vocab(tok)
                removed.append(i)
        else:
            tokens[i] = vocab(tok)
    if removed:
        tokens = [t for i, t in enumerate(tokens) if i not in removed]
    return torch.Tensor(tokens)
-
-
def process_image_collection(images):
    """Zero-pad a list of variable-length region tensors into one batch.

    Returns ``(all_images, img_lengths)`` where ``all_images`` has shape
    (batch, max_regions, dim) and ``img_lengths`` is a float tensor of
    the true region counts.
    """
    lengths = [len(img) for img in images]
    batch = torch.zeros(len(images), max(lengths), images[0].size(-1))
    for i, (img, n) in enumerate(zip(images, lengths)):
        batch[i, :n] = img[:n]
    return batch, torch.Tensor(lengths)
-
-
def get_relation(image):
    """Pairwise cosine-similarity matrix between region features.

    Args:
        image: 2-D tensor of shape (num_regions, feat_dim).

    Returns:
        (num_regions, num_regions) tensor of cosine similarities.
    """
    # Fix: the original called torch.nn.functional.normlize (typo ->
    # AttributeError) and never returned `att`, yet process_image_collection2
    # indexes the result. Use the correct `normalize` and return the matrix.
    image = torch.nn.functional.normalize(image, p=2, dim=-1)
    return image.mm(image.t())
-
-
def process_image_collection2(images, train):
    """Build two sampled-and-padded region views for each image.

    For every image, 5 anchor regions are sampled (twice, independently);
    each anchor contributes its top-5 most related regions according to
    ``get_relation``, giving two 25-region subsets per image. Each subset
    is passed through ``process_image`` (training-time augmentation) and
    zero-padded into a batch tensor.

    Returns:
        ([view1, view2], [lengths1, lengths2]) — two padded batch tensors
        and the corresponding float tensors of true region counts.
    """
    feats_a, feats_b = [], []
    lens_a, lens_b = [], []
    for image in images:
        att = get_relation(image)
        # Sample anchors for both views first to keep RNG usage in order.
        idx_a = np.random.choice(image.size(0), 5, replace=False)
        idx_b = np.random.choice(image.size(0), 5, replace=False)

        # Each anchor pulls in its 5 most similar regions -> 25 features.
        sub_a = image[att[idx_a, :].topk(5)[1]].reshape(-1, image.shape[1])
        sub_b = image[att[idx_b, :].topk(5)[1]].reshape(-1, image.shape[1])

        sub_a = process_image(sub_a, train)
        sub_b = process_image(sub_b, train)

        feats_a.append(sub_a)
        lens_a.append(sub_a.size(0))
        feats_b.append(sub_b)
        lens_b.append(sub_b.size(0))

    def _pad(feats, lens):
        # Zero-pad a list of (n_i, dim) tensors to (batch, max_n, dim).
        out = torch.zeros(len(feats), max(lens), feats[0].size(-1))
        for i, (f, n) in enumerate(zip(feats, lens)):
            out[i, :n] = f[:n]
        return out

    view_a = _pad(feats_a, lens_a)
    view_b = _pad(feats_b, lens_b)

    return [view_a, view_b], [torch.Tensor(lens_a), torch.Tensor(lens_b)]
-
-
def collate_fn(data):
    """Collate ``(image, caption, index, img_index)`` samples into a batch.

    Sorts the batch by descending caption length, zero-pads the region
    features and the caption id sequences.

    Returns:
        all_images: (batch, max_regions, dim) float tensor.
        img_lengths: float tensor of true region counts per image.
        targets: (batch, max_cap_len) long tensor of word ids.
        lengths: list of true caption lengths.
        ids: tuple of dataset indices.
    """
    # Sort by caption length so downstream packing can assume monotonic order.
    data.sort(key=lambda item: len(item[1]), reverse=True)
    images, captions, ids, img_ids = zip(*data)

    # Pad the image region features.
    all_images, img_lengths = process_image_collection(images)

    # Pad the captions.
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        n = lengths[i]
        targets[i, :n] = cap[:n]

    return all_images, img_lengths, targets, lengths, ids
-
-
-
def get_loader(data_path, data_name, data_split, vocab, opt, batch_size=100,
               shuffle=True, num_workers=2, train=True):
    """Return a ``DataLoader`` over precomputed region features and captions."""
    dset = PrecompRegionDataset(data_path, data_name, data_split, vocab, opt, train)
    # Incomplete final batches are dropped only while training.
    return torch.utils.data.DataLoader(dataset=dset,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       pin_memory=True,
                                       collate_fn=collate_fn,
                                       num_workers=num_workers,
                                       drop_last=bool(train))
-
-
def get_loaders(data_path, data_name, vocab, batch_size, workers, opt):
    """Return ``(train_loader, val_loader)``.

    NOTE: validation deliberately uses the 'test' split here, in eval mode
    (no shuffling, no augmentation, incomplete batches kept).
    """
    return (
        get_loader(data_path, data_name, 'train', vocab, opt,
                   batch_size, True, workers),
        get_loader(data_path, data_name, 'test', vocab, opt,
                   batch_size, False, workers, train=False),
    )
-
-
def get_train_loader(data_path, data_name, vocab, batch_size, workers, opt, shuffle):
    """Return a training loader over the 'train' split (caller controls shuffling)."""
    return get_loader(data_path, data_name, 'train', vocab, opt,
                      batch_size, shuffle, workers)
-
-
def get_test_loader(split_name, data_name, vocab, batch_size, workers, opt):
    """Return an evaluation loader for ``split_name`` (no shuffle, no augmentation)."""
    return get_loader(opt.data_path, data_name, split_name, vocab, opt,
                      batch_size, False, workers, train=False)
|