From e676710c291386cf5a696e384e29a18ad9451919 Mon Sep 17 00:00:00 2001 From: magic_liu <76960446@qq.com> Date: Mon, 6 Jun 2022 21:09:55 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20'data'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/image_dataset.py | 57 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 data/image_dataset.py diff --git a/data/image_dataset.py b/data/image_dataset.py new file mode 100644 index 0000000..5b7b7e9 --- /dev/null +++ b/data/image_dataset.py @@ -0,0 +1,57 @@ +# coding: utf-8 +from PIL import Image +from torch.utils.data import Dataset +import os.path as osp +# import torch +import json +# import prototype.spring.linklink as link +# import os + +# 集成Dataset类 +class ImageDataset(Dataset): + def __init__(self, root, txt_path, transform=None, target_transform=None): + """ + tex_path : txt文本路径,该文本包含了图像的路径信息,以及标签信息 + transform:数据处理,对图像进行随机剪裁,以及转换成tensor + """ + self.root = root + self.transform = transform + # self.evaluator = evaluator + imgs = [] + + with open(txt_path) as f: + lines = f.readlines() + + self.num = len(lines) + self.metas = [] + for line in lines: +# filename, label = line.rstrip().split() + # self.metas.append({'filename': filename, 'label': label}) +# imgs.append((filename, int(label))) + info = json.loads(line) #适应RobustART的标签文件 + self.metas.append(info) +# self.imgs = imgs + self.transform = transform + self.target_transform = target_transform + + def __getitem__(self, index): + #--------适应RobustART----------- + curr_meta = self.metas[index] #适应RobustART + filename = osp.join(self.root, curr_meta['filename']) + curr_meta['filename'] = filename + label = int(curr_meta['label']) if 'label' in curr_meta else 0 + #-------适应RobustART的标签文件------ +# filename, label = self.imgs[index] +# filename = osp.join(self.root, filename) +# label = int(label) + + img = Image.open(filename).convert('RGB') # 把图像转成RGB + if self.transform is not None: + img = self.transform(img) + if self.target_transform is not None: + label = self.transform(label) + + return img, label# 这就返回一个样本 + + def __len__(self): + return self.num # 返回长度,index就会自动的指导读取多少 -- 2.34.1