|
- import torch
- import numpy as np
- import json
- import os
-
class HybridLoader:
    """Load per-key feature arrays from a directory, an LMDB database or a .pth file.

    The backend is selected from ``db_path``:

    * a path ending in ``.lmdb`` is opened read-only as an LMDB environment;
      each stored value is decoded with the extension-specific loader;
    * a path ending in ``.pth`` is loaded once with ``torch.load`` and treated
      as a key -> value dictionary (``ext`` is ignored in this case);
    * anything else is treated as a directory containing one ``<key><ext>``
      file per key.

    The decoding depends on ``ext``: ``'.npy'`` inputs are read with
    ``np.load`` directly, any other extension is read with ``np.load`` and the
    ``'feat'`` entry of the result is returned.
    """

    def __init__(self, db_path, ext):
        """Select the storage backend and the decoder.

        Args:
            db_path: Directory, ``.lmdb`` path or ``.pth`` file to read from.
            ext: File extension (including the dot) that selects the decoder
                and, for the directory backend, is appended to each key.
        """
        self.db_path = db_path
        self.ext = ext
        if self.ext == '.npy':
            self.loader = self._load_array
        else:
            self.loader = self._load_feat_entry

        if db_path.endswith('.lmdb'):
            # Imported lazily so the directory and .pth backends keep working
            # when the optional lmdb package is not installed.
            import lmdb

            self.db_type = 'lmdb'
            self.env = lmdb.open(db_path, subdir=os.path.isdir(db_path),
                                 readonly=True, lock=False,
                                 readahead=False, meminit=False)
        elif db_path.endswith('.pth'):  # Assume a key,value dictionary
            self.db_type = 'pth'
            # NOTE(security): torch.load unpickles the file; only point this
            # at trusted .pth files.
            self.feat_file = torch.load(db_path)
            # Values are already in memory, so no decoding step is needed.
            self.loader = self._identity
            print('HybridLoader: ext is ignored')
        else:
            self.db_type = 'dir'

    @staticmethod
    def _load_array(source):
        """Read a plain ``.npy`` array from a path or file-like object."""
        return np.load(source)

    @staticmethod
    def _load_feat_entry(source):
        """Read the ``'feat'`` entry of an ``.npz`` archive and close the archive."""
        with np.load(source) as archive:
            return archive['feat']

    @staticmethod
    def _identity(value):
        """Return ``value`` unchanged (used by the in-memory .pth backend)."""
        return value

    def get(self, key):
        """Return the feature stored under ``key``.

        Args:
            key: Entry name. For the LMDB backend a ``str`` key is encoded to
                bytes before the lookup.

        Returns:
            Whatever the configured loader produces for the entry.

        Raises:
            KeyError: If the LMDB or .pth backend has no entry for ``key``.
        """
        if self.db_type == 'lmdb':
            import io

            # LMDB keys are bytes; accept str keys for convenience.
            raw_key = key.encode() if isinstance(key, str) else key
            with self.env.begin(write=False) as txn:
                byteflow = txn.get(raw_key)
                if byteflow is None:
                    # Fail clearly instead of handing an empty buffer to np.load.
                    raise KeyError(key)
                f_input = io.BytesIO(byteflow)
        elif self.db_type == 'pth':
            f_input = self.feat_file[key]
        else:
            f_input = os.path.join(self.db_path, key + self.ext)

        # Decode the raw input (path, byte buffer or in-memory value).
        feat = self.loader(f_input)

        return feat
-
# Export the fc/att features of the first few images as float32 .npy files.

# Read the dataset index; the context manager closes the file promptly.
with open("data/cocotalk.json") as info_file:
    info = json.load(info_file)

fc_loader = HybridLoader("data/cocobu_fc", '.npy')
att_loader = HybridLoader("data/cocobu_att", '.npz')

# np.save does not create missing parent directories.
os.makedirs("data/tmp_data", exist_ok=True)

# NOTE: a model call on (fc_feat, att_feat, att_mask) and its timing were
# previously sketched here as commented-out code; only the export is active.

# Slicing (rather than indexing 0..3) keeps this working when the index
# lists fewer than four images.
for i, image in enumerate(info['images'][:4]):
    image_id = str(image['id'])
    fc_feat = fc_loader.get(image_id)
    att_feat = att_loader.get(image_id)
    print(type(fc_feat),type(att_feat))

    # Round-trip through torch to obtain float32 copies of the features.
    fc_feat = torch.tensor(fc_feat).float()
    att_feat = torch.tensor(att_feat).float()

    fc_np = fc_feat.numpy()
    att_np = att_feat.numpy()

    np.save("data/tmp_data/fc_np" + str(i) + ".npy", fc_np)
    np.save("data/tmp_data/att_np" + str(i) + ".npy", att_np)
|