|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """FOTS dataset."""
- import os
- import multiprocessing
- import random
- import numpy as np
- import mindspore.numpy as ms_np
- import cv2
- import re
- from PIL import Image
- import mindspore.dataset as de
- import mindspore.dataset.vision.c_transforms as c_transforms
- import mindspore.dataset.vision.py_transforms as py_transforms
- from mindspore.dataset.transforms.py_transforms import Compose
- from src.distributed_sampler import DistributedSampler
- from src.config import configFOTS as config
- from src.transforms import transform
-
- from mindspore import ops
- import mindspore
# NOTE(review): not referenced anywhere in this file — presumably copied from a
# shared dataset template; confirm before removing.
min_keypoints_per_image = 10
# Worker count passed to GeneratorDataset and batch ops below.
GENERATOR_PARALLEL_WORKER = 1
-
-
class ICDAR2015Dataset_train:
    """FOTS training dataset for ICDAR2015.

    Scans ``img_dir`` for images and, for each sample, reads the matching
    ``gt_<name>.txt`` annotation from ``ann_file`` and applies ``transform``
    to produce the training targets.
    """

    def __init__(self, img_dir, ann_file, transform):
        self.transform = transform
        self.img_dir = img_dir
        self.labels_dir = ann_file
        # One prefix per image file; assumes a 4-character extension (".jpg").
        self.image_prefix = []
        # Matches 8 comma-separated integers (4 corner points) followed by the
        # transcription text.
        self.pattern = re.compile('^' + '(\\d+),' * 8 + '(.+)$')
        for dir_entry in os.scandir(self.img_dir):
            self.image_prefix.append(dir_entry.name[:-4])

        self.count = 0
        self.stack = ops.Stack()

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index.

        Returns:
            tuple: (image, classification, regression, thetas, training_mask).
        """
        img = cv2.imread(os.path.join(self.img_dir, self.image_prefix[idx] + '.jpg'),
                         cv2.IMREAD_COLOR).astype(np.float32)
        quads = []
        texts = []
        gt_path = os.path.join(self.labels_dir, 'gt_' + self.image_prefix[idx] + '.txt')
        # Context manager so the annotation file is always closed
        # (the original left the handle open).
        with open(gt_path, encoding='utf-8-sig') as gt_file:
            lines = [line.rstrip('\n') for line in gt_file]
        for line in lines:
            matches = self.pattern.findall(line)[0]
            numbers = np.array(matches[:8], dtype=float)
            quads.append(numbers.reshape((4, 2)))
            # '###' marks an illegible/ignored region; True means real text.
            texts.append('###' != matches[8])
        # NOTE: calls the module-level ``transform`` (as the original did),
        # not ``self.transform``.
        image, classification, regression, thetas, training_mask = transform(
            img, np.stack(quads), texts, self)
        return (image, classification, regression, thetas, training_mask)

    def __len__(self):
        return len(self.image_prefix)
-
-
class ICDAR2015Dataset_test:
    """FOTS evaluation dataset for ICDAR2015.

    Loads each image, rescales it to a fixed 2240x1248 resolution and
    normalizes it with ImageNet mean/std statistics.
    """

    def __init__(self, img_dir, ann_file):
        self.img_dir = img_dir
        self.labels_dir = ann_file
        # One prefix per image file; assumes a 4-character extension (".jpg").
        self.image_prefix = []
        # Annotation line pattern — kept for parity with the training dataset;
        # unused here because evaluation only returns images and scales.
        self.pattern = re.compile('^' + '(\\d+),' * 8 + '(.+)$')
        for dir_entry in os.scandir(self.img_dir):
            self.image_prefix.append(dir_entry.name[:-4])

        self.count = 0

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index.

        Returns:
            tuple: (scaled_image, img_idx, scale) where ``scale`` holds the
            (x, y) resize factors applied to the original image.
        """
        image = cv2.imread(os.path.join(self.img_dir, self.image_prefix[idx] + '.jpg'),
                           cv2.IMREAD_COLOR).astype(np.float32)

        # Fixed evaluation resolution: 2240x1248.
        scale_x = 2240 / image.shape[1]
        scale_y = 1248 / image.shape[0]

        scaled_image = cv2.resize(image, dsize=(0, 0), fx=scale_x, fy=scale_y,
                                  interpolation=cv2.INTER_CUBIC)

        # BGR -> RGB, then normalize with ImageNet mean/std.
        scaled_image = scaled_image[:, :, ::-1].astype(np.float32)
        scaled_image = ((scaled_image / 255 - np.array([0.485, 0.456, 0.406]))
                        / np.array([0.229, 0.224, 0.225])).astype(np.float32)

        return (scaled_image, np.array(idx, dtype=np.float32),
                np.array([scale_x, scale_y], dtype=np.float32))

    def __len__(self):
        return len(self.image_prefix)
-
-
-
-
def create_fots_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
                        config=None, is_training=True, shuffle=True):
    """Create a MindSpore dataset pipeline for FOTS.

    Args:
        image_dir (str): Directory containing the .jpg images.
        anno_path (str): Directory containing the gt_*.txt annotation files.
        batch_size (int): Batch size; incomplete batches are dropped.
        max_epoch (int): Number of times the dataset is repeated.
        device_num (int): Number of devices for distributed sampling.
        rank (int): Rank id of the current device.
        config: Optional config object; when given, its ``dataset_size``
            attribute is set to the dataset length.
        is_training (bool): Build the training pipeline when True, otherwise
            the evaluation pipeline.
        shuffle (bool): Shuffle flag for the distributed sampler.

    Returns:
        tuple: ``(ds, dataset_size)`` in training mode, or
        ``(ds, dataset_object)`` in evaluation mode — the asymmetry is kept
        as-is for backward compatibility with existing callers.
    """
    # OpenCV's own threads interfere with MindSpore's dataset parallelism.
    cv2.setNumThreads(0)

    if is_training:
        fots_dataset = ICDAR2015Dataset_train(img_dir=image_dir, ann_file=anno_path,
                                              transform=transform)
        dataset_column_names = ["image", "classification", "regression",
                                "thetas", "training_mask"]
    else:
        fots_dataset = ICDAR2015Dataset_test(img_dir=image_dir, ann_file=anno_path)
        dataset_column_names = ["image", "img_idx", "scale"]

    distributed_sampler = DistributedSampler(len(fots_dataset), device_num, rank,
                                             shuffle=shuffle)
    fots_dataset.size = len(distributed_sampler)

    # Guard against config=None — the original dereferenced the (default-None)
    # parameter unconditionally, which raised AttributeError.
    if config is not None:
        config.dataset_size = len(fots_dataset)

    ds = de.GeneratorDataset(fots_dataset, column_names=dataset_column_names,
                             sampler=de.RandomSampler(),
                             num_parallel_workers=GENERATOR_PARALLEL_WORKER)
    ds = ds.batch(batch_size, num_parallel_workers=GENERATOR_PARALLEL_WORKER,
                  drop_remainder=True)
    ds = ds.repeat(max_epoch)

    if is_training:
        return ds, len(fots_dataset)
    return ds, fots_dataset
-
-
if __name__ == '__main__':
    # Smoke test: build the training pipeline and print the shapes of one batch.
    image_dir = 'dataset/ICDAR2015/task4_1/ch4_training_images'
    anno_path = 'dataset/ICDAR2015/task4_1/ch4_training_localization_transcription_gt'
    fots_dataset = ICDAR2015Dataset_train(img_dir=image_dir, ann_file=anno_path,
                                          transform=transform)

    # create_fots_dataset returns (ds, dataset_size) in training mode.
    ds = create_fots_dataset(image_dir, anno_path, 32, 10, 1, 0, config)
    for x in ds[0].create_dict_iterator():
        print(type(x), x.keys())
        for k, v in x.items():
            print(k, v.shape)
        break
|