|
- import os
- from PIL import Image
- import random
- import numpy as np
- import torch
- import torchvision.transforms as tfs
- import random
-
class casia2(object):
    """Map-style dataset over CASIA v2 images and their ground-truth masks.

    The sample list is read from a text file containing one image file name
    per line. For every name an image path and a matching mask path
    (``<stem>_gt.png``) are recorded; the files themselves are only opened
    in ``__getitem__``.

    The data locations are class attributes so that a subclass (or an
    instance) can point at a different copy of the data without changing
    the constructor signature.

    Args:
        mode: ``"train"`` reads the training list; any other value reads
            the validation list.
        newsize: Side length in pixels; every image and mask is resized to
            ``(newsize, newsize)``.
    """

    IMAGE_DIR = "/dataset/casiav2/images/"
    MASK_DIR = "/dataset/casiav2/masks/"
    MASK_SUFFIX = "_gt.png"
    TRAIN_LIST = "./casia2_train.txt"
    VAL_LIST = "./casia2_val.txt"

    def __init__(self, mode="train", newsize=320):
        super(casia2, self).__init__()

        self.images = []
        self.masks = []

        self.newsize = newsize

        self.im_tfs = tfs.Compose([tfs.ToTensor()])

        if mode == "train":
            self.readImage_train()
        else:
            self.readImage_val()

    def _read_list(self, list_path):
        """Append the samples named in *list_path* to ``images``/``masks``.

        Names are shuffled before being appended, and the image and mask
        lists are extended in lockstep so index ``i`` always refers to the
        same sample in both.

        Raises:
            OSError: If *list_path* cannot be opened.
        """
        with open(list_path, "r") as f:
            names = [line.strip() for line in f]
        # A blank line (typically a trailing newline at end of file) would
        # otherwise become a sample pointing at the bare directory.
        names = [name for name in names if name]
        random.shuffle(names)

        for name in names:
            # splitext instead of slicing off four characters, so that
            # extensions such as ".jpeg" or ".tiff" map to the right mask.
            stem = os.path.splitext(name)[0]
            self.images.append(self.IMAGE_DIR + name)
            self.masks.append(self.MASK_DIR + stem + self.MASK_SUFFIX)

        print("images number:%d" % (len(self.images)))

    def readImage_train(self):
        """Load the training split listed in ``TRAIN_LIST``."""
        self._read_list(self.TRAIN_LIST)

    def readImage_val(self):
        """Load the validation split listed in ``VAL_LIST``."""
        self._read_list(self.VAL_LIST)

    def image_transform(self, img, img_gt):
        """Apply ``self.im_tfs`` to both the image and its mask.

        Returns:
            The transformed ``(img, img_gt)`` pair.
        """
        img = self.im_tfs(img)
        img_gt = self.im_tfs(img_gt)
        return img, img_gt

    def __getitem__(self, idx):
        """Return the transformed ``(image, mask)`` pair at position *idx*.

        The image is converted to RGB and the mask to single-channel 'L';
        both are resized to ``(newsize, newsize)`` before the transform.
        """
        size = (self.newsize, self.newsize)

        # Image.open is lazy; the context manager closes the underlying
        # file as soon as the converted copy has been made.
        with Image.open(self.images[idx]) as raw:
            img = raw.convert('RGB').resize(size)
        with Image.open(self.masks[idx]) as raw_gt:
            # NOTE(review): resize() uses Pillow's default resampling filter
            # here, which may interpolate mask edges to intermediate grey
            # values. Confirm whether NEAREST is wanted for the masks.
            img_gt = raw_gt.convert('L').resize(size)

        return self.image_transform(img, img_gt)

    def __len__(self):
        """Return the number of samples loaded from the list file."""
        return len(self.images)
-
-
if __name__ == '__main__':
    # Smoke test: build the training split, iterate it once in shuffled
    # batches of 10 and print the shape of every image batch.
    train_set = casia2("train", newsize=320)
    loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=10,
        shuffle=True,
    )
    for img, img_gt in loader:
        print(img.shape)
-
- '''
- lines = []
- with open("./list.txt", "r") as f:
- for line in f.readlines():
- line = line.strip('\n')
- lines.append(line)
- random.shuffle(lines)
-
- with open("./casia2_train.txt", "w") as f1:
- with open("./casia2_no.txt", "w") as f2:
- for line in lines:
- aa="/dataset/casiav2/images/" + line
- bb="/dataset/casiav2/masks/" + line[:-4] +"_gt.png"
-
- path1 = os.path.exists(aa)
- path2 = os.path.exists(bb)
- if path2:
- f1.write(line+'\n')
- else:
- f2.write(line+'\n')
-
- lines = []
- with open("./list.txt", "r") as f:
- for line in f.readlines():
- line = line.strip('\n')
- lines.append(line)
- random.shuffle(lines)
-
- with open("./casia2_train.txt", "w") as f1:
- for line in lines:
- f1.write(line+'\n')
- '''
|