|
- import numpy as np
- import imageio
- import os
- import random
- import argparse
-
# Build the data lists for training, validation and test.
# 10% of the sequence folders are held out for the test set.
- # 制作训练的麻烦
-
def img_dir_filter(f):
    """Return True when the file name ``f`` has a supported image extension.

    The recognised extensions are ``.jpg``, ``.png`` and ``.bmp``. The
    comparison ignores case so that names such as ``IMG001.PNG`` are not
    silently dropped.

    :param f: file name (or path) to test
    :return: True if ``f`` ends with a supported extension, otherwise False
    """
    # str.endswith accepts a tuple, so one call covers every extension.
    return f.lower().endswith(('.jpg', '.png', '.bmp'))
-
-
def data_split(full_list, ratio, shuffle=True):
    """Split ``full_list`` into two sub-lists according to ``ratio``.

    The first sub-list receives ``int(len(full_list) * ratio)`` items and the
    second one receives the rest. The caller's list is never modified:
    shuffling is performed on a copy.

    :param full_list: list of items to split
    :param ratio: fraction of the items that goes into the first sub-list
    :param shuffle: if True, shuffle the items with ``random.shuffle``
        before splitting
    :return: tuple ``(sublist_1, sublist_2)``; ``sublist_1`` is empty when
        fewer than one item would fall into it
    """
    # Work on a copy so the caller's list keeps its original order.
    items = list(full_list)
    offset = int(len(items) * ratio)
    if offset < 1:
        # Also covers the empty-input case (offset is 0).
        return [], items
    if shuffle:
        random.shuffle(items)
    return items[:offset], items[offset:]
-
-
def make_pos_sample():
    """Placeholder with no implementation yet; does nothing and returns None."""
    return None
-
def make_neg_sample():
    """Placeholder with no implementation yet; does nothing and returns None."""
    return None
-
def random_sample(root, file_list, num_frames):
    """Build one negative sample out of randomly chosen frames.

    Each of the ``num_frames`` frames is drawn independently: first a random
    folder from ``file_list``, then a random image file inside that folder.

    :param root: directory that contains the sequence folders
    :param file_list: names of the sequence folders, relative to ``root``
    :param num_frames: number of frames in the sample
    :return: list ``['0', path_1, ..., path_n]`` where ``'0'`` is the label
        and every path is relative to ``root``
    :raises ValueError: if ``file_list`` is empty or a chosen folder holds
        no image files
    """
    if num_frames > 0 and not file_list:
        raise ValueError("file_list is empty, cannot draw a random sample")

    # Label comes first: '0' marks a negative sample.
    seq_neg = ['0']
    # Remember each folder's listing so it is read from disk at most once
    # per call, even when the same folder is drawn several times.
    listing_cache = {}
    for _ in range(num_frames):
        folder = random.choice(file_list)
        if folder not in listing_cache:
            listing_cache[folder] = [
                name
                for name in os.listdir(os.path.join(root, folder))
                if img_dir_filter(name)
            ]
        images = listing_cache[folder]
        if not images:
            raise ValueError("no image files found in folder: " + folder)
        # Pick a name that really exists in the folder rather than composing
        # one from an index, which breaks for any other naming scheme.
        seq_neg.append(os.path.join(folder, random.choice(images)))
    return seq_neg
-
def make_train_list(root, file_list, num_frames, train_ratio=0.8):
    """Build the training and validation sequence lists.

    For every folder, each run of ``num_frames`` consecutive images (in
    case-insensitive name order) becomes a positive sample labelled ``'1'``.
    Each positive sample is followed by one negative sample produced by
    ``random_sample``. The combined list is then split with ``data_split``.

    :param root: directory that contains the sequence folders
    :param file_list: names of the sequence folders, relative to ``root``
    :param num_frames: number of frames per sample
    :param train_ratio: fraction of the samples placed in the training list
    :return: tuple ``(train_seq_list, val_seq_list)``; every sample is a list
        ``[label, path_1, ..., path_n]`` with paths relative to ``root``
    """
    seq_list = []
    for folder in file_list:
        # Keep image files only, ordered by name so frames are consecutive.
        images = sorted(
            filter(img_dir_filter, os.listdir(os.path.join(root, folder))),
            key=str.lower,
        )
        # A window starting at ``start`` uses indices
        # start .. start + num_frames - 1, so the last valid start is
        # len(images) - num_frames (inclusive).
        for start in range(len(images) - num_frames + 1):
            window = images[start:start + num_frames]
            seq_list.append(
                ['1'] + [os.path.join(folder, name) for name in window]
            )
            # Pair every positive sample with one random negative sample.
            seq_list.append(random_sample(root, file_list, num_frames))
    train_seq_list, val_seq_list = data_split(seq_list, ratio=train_ratio)
    return train_seq_list, val_seq_list
-
def make_test_list(root, file_list, num_frames):
    """Build the test sequence list (positive samples only).

    For every folder, each run of ``num_frames`` consecutive images (in
    case-insensitive name order) becomes one sample labelled ``'1'``.

    :param root: directory that contains the sequence folders
    :param file_list: names of the sequence folders, relative to ``root``
    :param num_frames: number of frames per sample
    :return: list of samples, each ``['1', path_1, ..., path_n]`` with paths
        relative to ``root``
    """
    seq_list = []
    for folder in file_list:
        # Keep image files only, ordered by name so frames are consecutive.
        images = sorted(
            filter(img_dir_filter, os.listdir(os.path.join(root, folder))),
            key=str.lower,
        )
        # A window starting at ``start`` uses indices
        # start .. start + num_frames - 1, so the last valid start is
        # len(images) - num_frames (inclusive).
        for start in range(len(images) - num_frames + 1):
            window = images[start:start + num_frames]
            seq_list.append(
                ['1'] + [os.path.join(folder, name) for name in window]
            )
    return seq_list
-
def save_seq_list(seq_list, save_path):
    """Write the samples in ``seq_list`` to ``save_path``, one per line.

    Every item of a sample is followed by a single space and each line ends
    with a newline. The file is written as UTF-8, the encoding that
    ``load_data_trial`` uses when reading it back, so paths with non-ASCII
    characters survive regardless of the platform's default encoding.

    :param seq_list: iterable of samples, each an iterable of strings
    :param save_path: path of the text file to create or overwrite
    """
    # The ``with`` block closes the file; no explicit close() is needed.
    with open(save_path, 'w', encoding='UTF-8') as f:
        f.writelines(
            ''.join(item + ' ' for item in seq) + '\n' for seq in seq_list
        )
-
def check_data(root, num_frames):
    """Return the folders under ``root`` that hold enough entries.

    A folder is kept when it contains more than ``num_frames`` entries; the
    names of all other folders are printed under a header line. Entries of
    ``root`` that are not directories are skipped.

    :param root: directory that contains the sequence folders
    :param num_frames: number of frames per sample
    :return: names of the accepted folders, sorted by name
    """
    safe_file_list = []
    # os.listdir returns entries in arbitrary order; sorting keeps the result
    # (and any split derived from it) reproducible between runs.
    entries = sorted(os.listdir(root))
    print("******* empty or lack data ********")
    for name in entries:
        folder = os.path.join(root, name)
        if not os.path.isdir(folder):
            # A stray file next to the sequence folders cannot be listed.
            continue
        if len(os.listdir(folder)) - 1 < num_frames:
            print(name)
        else:
            safe_file_list.append(name)
    return safe_file_list
-
def main(args):
    """Create ``train.txt``, ``val.txt`` and ``test.txt`` in ``args.save_dir``.

    The folders accepted by ``check_data`` are divided by position: the last
    10% form the test set and the remainder is used for training and
    validation.

    :param args: object with the attributes ``dir`` (dataset root),
        ``num_frames`` (frames per sample) and ``save_dir`` (output directory)
    """
    root = args.dir
    num_frames = args.num_frames
    save_dir = args.save_dir
    if save_dir:
        # Create the output directory up front so the writes cannot fail on
        # a missing path after all the lists have been built.
        os.makedirs(save_dir, exist_ok=True)

    file_list = check_data(root, num_frames)

    num_test_file = int(len(file_list) * 0.1)
    # Use an explicit split index: with negative slicing, zero test folders
    # would give file_list[:-0] == [] and file_list[-0:] == file_list, i.e.
    # an empty training set and every folder in the test set.
    split_index = len(file_list) - num_test_file
    train_file_list = file_list[:split_index]
    test_file_list = file_list[split_index:]

    print("*****start to make train seq list*****")
    train_seq_list, val_seq_list = make_train_list(
        root, train_file_list, num_frames
    )
    save_seq_list(train_seq_list, os.path.join(save_dir, 'train.txt'))
    save_seq_list(val_seq_list, os.path.join(save_dir, 'val.txt'))
    print("*****End for train seq list*****")

    print("*****start to make test seq list*****")
    test_seq_list = make_test_list(root, test_file_list, num_frames)
    save_seq_list(test_seq_list, os.path.join(save_dir, 'test.txt'))
    print("*****End for test seq list*****")
-
-
-
def load_data_trial(root):
    """Read a sequence list file, print its first sample and return all.

    Each line is split on single spaces; the empty trailing field produced by
    a space at the end of the line is removed.

    :param root: path of the list file (read as UTF-8)
    :return: list of samples, each a list of strings; empty for an empty file
    """
    data = []
    # The ``with`` block closes the file even if reading raises.
    with open(root, mode='r', encoding='UTF-8') as file:
        for content_line in file:
            item = content_line.strip('\n').split(' ')
            if item[-1] == '':
                item = item[:-1]
            data.append(item)
    if data:
        # Nothing to show for an empty file; data[0] would raise IndexError.
        print(data[0])
    return data
-
-
if __name__ == "__main__":
    # Command-line entry point: build the list files, or, when
    # --load_data_trial is given, read back test.txt from the save directory.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dir",
        default="/media/alpha4TB/mayongjia/research/LvShuai/sketch/datasets/PicExport新!/",
    )
    parser.add_argument("--save_dir", default="../datasets/")
    parser.add_argument("--num_frames", default=8, type=int)
    parser.add_argument("--load_data_trial", action='store_true')
    opt = parser.parse_args()
    if not opt.load_data_trial:
        main(opt)
    else:
        load_data_trial(os.path.join(opt.save_dir, 'test.txt'))
|