#2 上传文件至 'data'

Merged
alpha_magic merged 1 commits from magic_liu-patch-2 into master 1 year ago
  1. +57
    -0
      data/image_dataset.py

+ 57
- 0
data/image_dataset.py View File

@@ -0,0 +1,57 @@
# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
import os.path as osp
# import torch
import json
# import prototype.spring.linklink as link
# import os
# 集成Dataset类
class ImageDataset(Dataset):
def __init__(self, root, txt_path, transform=None, target_transform=None):
"""
tex_path : txt文本路径,该文本包含了图像的路径信息,以及标签信息
transform:数据处理,对图像进行随机剪裁,以及转换成tensor
"""
self.root = root
self.transform = transform
# self.evaluator = evaluator
imgs = []
with open(txt_path) as f:
lines = f.readlines()
self.num = len(lines)
self.metas = []
for line in lines:
# filename, label = line.rstrip().split()
# self.metas.append({'filename': filename, 'label': label})
# imgs.append((filename, int(label)))
info = json.loads(line) #适应RobustART的标签文件
self.metas.append(info)
# self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
#--------适应RobustART-----------
curr_meta = self.metas[index] #适应RobustART
filename = osp.join(self.root, curr_meta['filename'])
curr_meta['filename'] = filename
label = int(curr_meta['label']) if 'label' in curr_meta else 0
#-------适应RobustART的标签文件------
# filename, label = self.imgs[index]
# filename = osp.join(self.root, filename)
# label = int(label)
img = Image.open(filename).convert('RGB') # 把图像转成RGB
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
label = self.transform(label)
return img, label# 这就返回一个样本
def __len__(self):
return self.num # 返回长度,index就会自动的指导读取多少

Loading…
Cancel
Save