|
import multiprocessing
import os
import shutil
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Optional, Tuple

import scipy.io
from torchvision.datasets.utils import extract_archive

import spikingjelly.datasets as sjds
from .. import configure
from ..datasets import np_savez
-
class ASLDVS(sjds.NeuromorphicDatasetFolder):
    def __init__(
            self,
            root: str,
            data_type: str = 'event',
            frames_number: Optional[int] = None,
            split_by: Optional[str] = None,
            duration: Optional[int] = None,
            custom_integrate_function: Optional[Callable] = None,
            custom_integrated_frames_dir_name: Optional[str] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ) -> None:
        """
        The ASL-DVS dataset, which is proposed by `Graph-based Object Classification for Neuromorphic Vision Sensing <https://openaccess.thecvf.com/content_ICCV_2019/html/Bi_Graph-Based_Object_Classification_for_Neuromorphic_Vision_Sensing_ICCV_2019_paper.html>`_.

        Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about params information.
        """
        # ``train`` is fixed to ``None``: ASL-DVS ships no official train/test split,
        # so the base class treats the whole dataset as one folder tree.
        super().__init__(root, None, data_type, frames_number, split_by, duration,
                         custom_integrate_function, custom_integrated_frames_dir_name,
                         transform, target_transform)

    @staticmethod
    def resource_url_md5() -> list:
        '''
        :return: A list ``url`` that ``url[i]`` is a tuple, which contains the i-th file's name, download link, and MD5
        :rtype: list

        The MD5 entry is ``None`` because the dropbox-packed zip's MD5 is not stable,
        so the base class skips the checksum verification for this file.
        '''
        print('The ICCV2019_DVS_dataset.zip is packed by dropbox. We find that the MD5 of this zip file can change. So, MD5 check will not be used for this ASL-DVS dataset.')
        url = 'https://www.dropbox.com/sh/ibq0jsicatn7l6r/AACNrNELV56rs1YInMWUs9CAa?dl=0'
        return [
            ('ICCV2019_DVS_dataset.zip', url, None)
        ]

    @staticmethod
    def downloadable() -> bool:
        '''
        :return: Whether the dataset can be directly downloaded by python codes. If not, the user have to download it manually
        :rtype: bool
        '''
        # The dropbox share link can not be fetched programmatically; the user must
        # download the zip in a browser and place it in the download root.
        return False

    @staticmethod
    def extract_downloaded_files(download_root: str, extract_root: str):
        '''
        :param download_root: Root directory path which saves downloaded dataset files
        :type download_root: str
        :param extract_root: Root directory path which saves extracted files from downloaded files
        :type extract_root: str
        :return: None

        This function defines how to extract download files.

        The outer zip contains one inner zip per class; the outer archive is unpacked
        into a temporary directory, the inner zips are extracted concurrently into
        ``extract_root``, and the temporary directory is removed afterwards.
        '''
        temp_ext_dir = os.path.join(download_root, 'temp_ext')
        os.mkdir(temp_ext_dir)
        print(f'Mkdir [{temp_ext_dir}].')
        extract_archive(os.path.join(download_root, 'ICCV2019_DVS_dataset.zip'), temp_ext_dir)
        # Extraction is I/O-bound; cap the pool at 2 workers to avoid thrashing the disk.
        with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 2)) as tpe:
            sub_threads = []
            for zip_file in os.listdir(temp_ext_dir):
                if os.path.splitext(zip_file)[1] == '.zip':
                    zip_file = os.path.join(temp_ext_dir, zip_file)
                    print(f'Extract [{zip_file}] to [{extract_root}].')
                    sub_threads.append(tpe.submit(extract_archive, zip_file, extract_root))
            for sub_thread in sub_threads:
                if sub_thread.exception():
                    print(sub_thread.exception())
                    # sys.exit instead of the site builtin ``exit``, which is not
                    # guaranteed to exist (e.g. when running with ``python -S``).
                    sys.exit(-1)

        shutil.rmtree(temp_ext_dir)
        print(f'Rmtree [{temp_ext_dir}].')

    @staticmethod
    def load_origin_data(file_name: str) -> Dict:
        '''
        :param file_name: path of the events file
        :type file_name: str
        :return: a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``
        :rtype: Dict

        This function defines how to read the origin binary data.
        '''
        events = scipy.io.loadmat(file_name)
        # The .mat files store keys 'ts', 'x', 'y', 'pol'. Coordinates are mirrored
        # (239 - x, 179 - y) to fit the (H, W) = (180, 240) frame returned by
        # ``get_H_W`` — presumably the raw data uses the opposite corner as origin.
        events = {
            't': events['ts'].squeeze(),
            'x': 239 - events['x'].squeeze(),
            'y': 179 - events['y'].squeeze(),
            'p': events['pol'].squeeze()
        }

        return events

    @staticmethod
    def get_H_W() -> Tuple:
        '''
        :return: A tuple ``(H, W)``, where ``H`` is the height of the data and ``W`` is the width of the data.
            For example, this function returns ``(128, 128)`` for the DVS128 Gesture dataset.
        :rtype: tuple
        '''
        return 180, 240

    @staticmethod
    def read_mat_save_to_np(mat_file: str, np_file: str):
        '''
        :param mat_file: path of a source ``.mat`` events file
        :type mat_file: str
        :param np_file: path of the target ``.npz`` file
        :type np_file: str
        :return: None

        Reads one origin ``.mat`` file and saves its events as a compressed ``.npz`` file.
        '''
        events = ASLDVS.load_origin_data(mat_file)
        np_savez(np_file,
                 t=events['t'],
                 x=events['x'],
                 y=events['y'],
                 p=events['p']
                 )
        print(f'Save [{mat_file}] to [{np_file}].')

    @staticmethod
    def create_events_np_files(extract_root: str, events_np_root: str):
        '''
        :param extract_root: Root directory path which saves extracted files from downloaded files
        :type extract_root: str
        :param events_np_root: Root directory path which saves events files in the ``npz`` format
        :type events_np_root: str
        :return: None

        This function defines how to convert the origin binary data in ``extract_root`` to ``npz`` format and save converted files in ``events_np_root``.
        '''
        t_ckp = time.time()
        with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), configure.max_threads_number_for_datasets_preprocess)) as tpe:
            sub_threads = []
            # One sub-directory per class; mirror the layout under events_np_root.
            for class_name in os.listdir(extract_root):
                mat_dir = os.path.join(extract_root, class_name)
                np_dir = os.path.join(events_np_root, class_name)
                os.mkdir(np_dir)
                print(f'Mkdir [{np_dir}].')
                for bin_file in os.listdir(mat_dir):
                    source_file = os.path.join(mat_dir, bin_file)
                    target_file = os.path.join(np_dir, os.path.splitext(bin_file)[0] + '.npz')
                    print(f'Start to convert [{source_file}] to [{target_file}].')
                    sub_threads.append(tpe.submit(ASLDVS.read_mat_save_to_np, source_file,
                                                  target_file))
            for sub_thread in sub_threads:
                if sub_thread.exception():
                    print(sub_thread.exception())
                    # sys.exit instead of the site builtin ``exit`` (see
                    # ``extract_downloaded_files``).
                    sys.exit(-1)

        print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
|