|
- import sys
- from pathlib import Path
- from unittest.mock import MagicMock, patch
-
- import pytest
-
- import mmcv
- from mmcv import BaseStorageBackend, FileClient
-
# Register a separate MagicMock under each optional storage-SDK module name.
# These are the module paths that the @patch decorators in the tests target
# ('ceph.S3Client', 'petrel_client.client.Client', 'mc.*').
sys.modules.update({
    stub_name: MagicMock()
    for stub_name in ('ceph', 'petrel_client', 'petrel_client.client', 'mc')
})
-
-
class MockS3Client:
    """Stand-in client whose ``Get`` serves bytes from local files."""

    def __init__(self, enable_mc=True):
        # The flag is only recorded; nothing in this class reads it back.
        self.enable_mc = enable_mc

    def Get(self, filepath):
        """Return the raw bytes of the local file at ``filepath``."""
        with open(filepath, 'rb') as stream:
            return stream.read()
-
-
class MockMemcachedClient:
    """Stand-in client whose ``Get`` fills a buffer from a local file."""

    def __init__(self, server_list_cfg, client_cfg):
        # Both configuration arguments are accepted and ignored.
        pass

    def Get(self, filepath, buffer):
        """Read ``filepath`` as bytes and store them on ``buffer.content``."""
        with open(filepath, 'rb') as stream:
            payload = stream.read()
        buffer.content = payload
-
-
class TestFileClient:
    """Tests for ``FileClient`` and the storage backends it dispatches to."""

    @classmethod
    def setup_class(cls):
        """Set the fixture paths (under ``data/`` beside this module)."""
        cls.test_data_dir = Path(__file__).parent / 'data'
        cls.img_path = cls.test_data_dir / 'color.jpg'
        cls.img_shape = (300, 400, 3)
        cls.text_path = cls.test_data_dir / 'filelist.txt'

    def test_error(self):
        """An unknown backend name is rejected with ``ValueError``."""
        with pytest.raises(ValueError):
            FileClient('hadoop')

    def test_disk_backend(self):
        """The disk backend returns file contents for Path and str inputs."""
        disk_backend = FileClient('disk')

        # read_bytes()/read_text() close the file before returning, whereas a
        # bare ``open(...).read()`` leaves the handle open until collected.
        expected_bytes = self.img_path.read_bytes()
        expected_text = self.text_path.read_text()

        # input path is a Path object, then a str
        for img_path in (self.img_path, str(self.img_path)):
            img_bytes = disk_backend.get(img_path)
            img = mmcv.imfrombytes(img_bytes)
            assert img_bytes == expected_bytes
            assert img.shape == self.img_shape

        # input path is a Path object, then a str
        for text_path in (self.text_path, str(self.text_path)):
            assert disk_backend.get_text(text_path) == expected_text

    @patch('ceph.S3Client', MockS3Client)
    def test_ceph_backend(self):
        """The ceph backend warns on creation, serves bytes and maps paths."""
        # NOTE(review): the spelling in ``match`` presumably mirrors the
        # warning text emitted by the library -- confirm before changing it.
        with pytest.warns(
                Warning, match='Ceph is deprecate in favor of Petrel.'):
            FileClient('ceph')
        ceph_backend = FileClient('ceph')

        # get_text is not implemented, for Path and str inputs alike
        for text_path in (self.text_path, str(self.text_path)):
            with pytest.raises(NotImplementedError):
                ceph_backend.get_text(text_path)

        # get returns decodable image bytes for Path and str inputs
        for img_path in (self.img_path, str(self.img_path)):
            img_bytes = ceph_backend.get(img_path)
            img = mmcv.imfrombytes(img_bytes)
            assert img.shape == self.img_shape

        # `path_mapping` is either None or dict
        with pytest.raises(AssertionError):
            FileClient('ceph', path_mapping=1)
        # test `path_mapping`
        ceph_path = 's3://user/data'
        ceph_backend = FileClient(
            'ceph', path_mapping={str(self.test_data_dir): ceph_path})
        # Fetch the real bytes through the mock client once, then replace its
        # Get with a MagicMock so the path it receives can be inspected.
        ceph_backend.client._client.Get = MagicMock(
            return_value=ceph_backend.client._client.Get(self.img_path))
        img_bytes = ceph_backend.get(self.img_path)
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape
        ceph_backend.client._client.Get.assert_called_with(
            str(self.img_path).replace(str(self.test_data_dir), ceph_path))

    @patch('petrel_client.client.Client', MockS3Client)
    def test_petrel_backend(self):
        """The petrel backend serves bytes and applies ``path_mapping``."""
        petrel_backend = FileClient('petrel')

        # get_text is not implemented, for Path and str inputs alike
        for text_path in (self.text_path, str(self.text_path)):
            with pytest.raises(NotImplementedError):
                petrel_backend.get_text(text_path)

        # get returns decodable image bytes for Path and str inputs
        for img_path in (self.img_path, str(self.img_path)):
            img_bytes = petrel_backend.get(img_path)
            img = mmcv.imfrombytes(img_bytes)
            assert img.shape == self.img_shape

        # `path_mapping` is either None or dict
        with pytest.raises(AssertionError):
            FileClient('petrel', path_mapping=1)
        # test `path_mapping`
        petrel_path = 's3://user/data'
        petrel_backend = FileClient(
            'petrel', path_mapping={str(self.test_data_dir): petrel_path})
        # Fetch the real bytes through the mock client once, then replace its
        # Get with a MagicMock so the path it receives can be inspected.
        petrel_backend.client._client.Get = MagicMock(
            return_value=petrel_backend.client._client.Get(self.img_path))
        img_bytes = petrel_backend.get(self.img_path)
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape
        petrel_backend.client._client.Get.assert_called_with(
            str(self.img_path).replace(str(self.test_data_dir), petrel_path))

    @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient)
    @patch('mc.pyvector', MagicMock)
    @patch('mc.ConvertBuffer', lambda x: x.content)
    def test_memcached_backend(self):
        """The memcached backend serves bytes for Path and str inputs."""
        mc_cfg = dict(server_list_cfg='', client_cfg='', sys_path=None)
        mc_backend = FileClient('memcached', **mc_cfg)

        # get_text is not implemented, for Path and str inputs alike
        for text_path in (self.text_path, str(self.text_path)):
            with pytest.raises(NotImplementedError):
                mc_backend.get_text(text_path)

        # get returns decodable image bytes for Path and str inputs
        for img_path in (self.img_path, str(self.img_path)):
            img_bytes = mc_backend.get(img_path)
            img = mmcv.imfrombytes(img_bytes)
            assert img.shape == self.img_shape

    def test_lmdb_backend(self):
        """The lmdb backend accepts ``db_path`` as a Path or a str."""
        lmdb_path = self.test_data_dir / 'demo.lmdb'

        # First pass: db_path and text path are Path objects; second: str.
        variants = (
            (lmdb_path, self.text_path),
            (str(lmdb_path), str(self.text_path)),
        )
        for db_path, text_path in variants:
            lmdb_backend = FileClient('lmdb', db_path=db_path)

            with pytest.raises(NotImplementedError):
                lmdb_backend.get_text(text_path)

            img_bytes = lmdb_backend.get('baboon')
            img = mmcv.imfrombytes(img_bytes)
            assert img.shape == (120, 125, 3)

    def test_register_backend(self):
        """``register_backend`` validates input and honours ``force``."""

        # name must be a string
        with pytest.raises(TypeError):

            class TestClass1:
                pass

            FileClient.register_backend(1, TestClass1)

        # module must be a class
        with pytest.raises(TypeError):
            FileClient.register_backend('int', 0)

        # module must be a subclass of BaseStorageBackend
        with pytest.raises(TypeError):

            class TestClass1:
                pass

            FileClient.register_backend('TestClass1', TestClass1)

        class ExampleBackend(BaseStorageBackend):

            def get(self, filepath):
                return filepath

            def get_text(self, filepath):
                return filepath

        # plain call form
        FileClient.register_backend('example', ExampleBackend)
        example_backend = FileClient('example')
        assert example_backend.get(self.img_path) == self.img_path
        assert example_backend.get_text(self.text_path) == self.text_path
        assert 'example' in FileClient._backends

        class Example2Backend(BaseStorageBackend):

            def get(self, filepath):
                return 'bytes2'

            def get_text(self, filepath):
                return 'text2'

        # force=False
        with pytest.raises(KeyError):
            FileClient.register_backend('example', Example2Backend)

        # force=True replaces the existing registration
        FileClient.register_backend('example', Example2Backend, force=True)
        example_backend = FileClient('example')
        assert example_backend.get(self.img_path) == 'bytes2'
        assert example_backend.get_text(self.text_path) == 'text2'

        # decorator form
        @FileClient.register_backend(name='example3')
        class Example3Backend(BaseStorageBackend):

            def get(self, filepath):
                return 'bytes3'

            def get_text(self, filepath):
                return 'text3'

        example_backend = FileClient('example3')
        assert example_backend.get(self.img_path) == 'bytes3'
        assert example_backend.get_text(self.text_path) == 'text3'
        assert 'example3' in FileClient._backends

        # force=False
        with pytest.raises(KeyError):

            @FileClient.register_backend(name='example3')
            class Example4Backend(BaseStorageBackend):

                def get(self, filepath):
                    return 'bytes4'

                def get_text(self, filepath):
                    return 'text4'

        # force=True replaces the existing registration
        @FileClient.register_backend(name='example3', force=True)
        class Example5Backend(BaseStorageBackend):

            def get(self, filepath):
                return 'bytes5'

            def get_text(self, filepath):
                return 'text5'

        example_backend = FileClient('example3')
        assert example_backend.get(self.img_path) == 'bytes5'
        assert example_backend.get_text(self.text_path) == 'text5'
|