|
|
@@ -0,0 +1,57 @@ |
|
|
|
# coding: utf-8
|
|
|
|
from PIL import Image
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
import os.path as osp
|
|
|
|
# import torch
|
|
|
|
import json
|
|
|
|
# import prototype.spring.linklink as link
|
|
|
|
# import os
|
|
|
|
|
|
|
|
# 集成Dataset类
|
|
|
|
class ImageDataset(Dataset):
|
|
|
|
def __init__(self, root, txt_path, transform=None, target_transform=None):
|
|
|
|
"""
|
|
|
|
tex_path : txt文本路径,该文本包含了图像的路径信息,以及标签信息
|
|
|
|
transform:数据处理,对图像进行随机剪裁,以及转换成tensor
|
|
|
|
"""
|
|
|
|
self.root = root
|
|
|
|
self.transform = transform
|
|
|
|
# self.evaluator = evaluator
|
|
|
|
imgs = []
|
|
|
|
|
|
|
|
with open(txt_path) as f:
|
|
|
|
lines = f.readlines()
|
|
|
|
|
|
|
|
self.num = len(lines)
|
|
|
|
self.metas = []
|
|
|
|
for line in lines:
|
|
|
|
# filename, label = line.rstrip().split()
|
|
|
|
# self.metas.append({'filename': filename, 'label': label})
|
|
|
|
# imgs.append((filename, int(label)))
|
|
|
|
info = json.loads(line) #适应RobustART的标签文件
|
|
|
|
self.metas.append(info)
|
|
|
|
# self.imgs = imgs
|
|
|
|
self.transform = transform
|
|
|
|
self.target_transform = target_transform
|
|
|
|
|
|
|
|
def __getitem__(self, index):
|
|
|
|
#--------适应RobustART-----------
|
|
|
|
curr_meta = self.metas[index] #适应RobustART
|
|
|
|
filename = osp.join(self.root, curr_meta['filename'])
|
|
|
|
curr_meta['filename'] = filename
|
|
|
|
label = int(curr_meta['label']) if 'label' in curr_meta else 0
|
|
|
|
#-------适应RobustART的标签文件------
|
|
|
|
# filename, label = self.imgs[index]
|
|
|
|
# filename = osp.join(self.root, filename)
|
|
|
|
# label = int(label)
|
|
|
|
|
|
|
|
img = Image.open(filename).convert('RGB') # 把图像转成RGB
|
|
|
|
if self.transform is not None:
|
|
|
|
img = self.transform(img)
|
|
|
|
if self.target_transform is not None:
|
|
|
|
label = self.transform(label)
|
|
|
|
|
|
|
|
return img, label# 这就返回一个样本
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
return self.num # 返回长度,index就会自动的指导读取多少
|