- #! /usr/bin/python
- # -*- coding: utf-8 -*-
-
- import base64
- import datetime
- import gzip
- import json
- import math
- import os
- import pickle
- import re
- import shutil
- # import ast
- import sys
- import tarfile
- import time
- import zipfile
-
- import cloudpickle
- import h5py
- import numpy as np
- import progressbar
- import scipy.io as sio
- from six.moves import cPickle
- import tensorlayerx as tlx
- from tensorlayerx import logging
- from tensorlayerx.utils import visualize
-
- if tlx.BACKEND == 'tensorflow':
- from tensorflow.python.keras.saving import model_config as model_config_lib
- from tensorflow.python.platform import gfile
- from tensorflow.python.util import serialization
- from tensorflow.python.util.tf_export import keras_export
- from tensorflow.python import pywrap_tensorflow
- import tensorflow as tf
-
- @keras_export('keras.model.save_model')
- def save_keras_model(model):
- # f.attrs['keras_model_config'] = json.dumps(
- # {
- # 'class_name': model.__class__.__name__,
- # 'config': model.get_config()
- # },
- # default=serialization.get_json_type).encode('utf8')
- #
- # f.flush()
-
- return json.dumps(
- {
- 'class_name': model.__class__.__name__,
- 'config': model.get_config()
- }, default=serialization.get_json_type
- ).encode('utf8')
-
- @keras_export('keras.model.load_model')
- def load_keras_model(model_config):
-
- custom_objects = {}
-
- if model_config is None:
- raise ValueError('No model found in config.')
- model_config = json.loads(model_config.decode('utf-8'))
- model = model_config_lib.model_from_config(model_config, custom_objects=custom_objects)
-
- return model
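-
- # A minimal sketch (assuming tf.keras is available and the model implements
- # get_config()) of round-tripping an architecture through the two helpers above:
- # >>> m = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
- # >>> config_bytes = save_keras_model(m)   # JSON bytes describing the architecture
- # >>> m2 = load_keras_model(config_bytes)  # rebuilt architecture; weights are not restored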
-
-
- if tlx.BACKEND == 'mindspore':
- from mindspore.ops.operations import Assign
- from mindspore.nn import Cell
- from mindspore import Tensor
- import mindspore as ms
- if tlx.BACKEND == 'paddle':
- import paddle as pd
- if tlx.BACKEND == 'torch':
- import torch
-
- if sys.version_info[0] == 2:
- from urllib import urlretrieve
- else:
- from urllib.request import urlretrieve
-
- # import tensorflow.contrib.eager.python.saver as tfes
- # TODO: tf2.0 not stable, cannot import tensorflow.contrib.eager.python.saver
-
- __all__ = [
- 'assign_weights',
- 'del_file',
- 'del_folder',
- 'download_file_from_google_drive',
- 'exists_or_mkdir',
- 'file_exists',
- 'folder_exists',
- 'load_and_assign_npz',
- 'load_and_assign_npz_dict',
- 'load_ckpt',
- 'load_cropped_svhn',
- 'load_file_list',
- 'load_folder_list',
- 'load_npy_to_any',
- 'load_npz',
- 'maybe_download_and_extract',
- 'natural_keys',
- 'npz_to_W_pdf',
- 'read_file',
- 'save_any_to_npy',
- 'save_ckpt',
- 'save_npz',
- 'save_npz_dict',
- 'tf_variables_to_numpy',
- 'ms_variables_to_numpy',
- 'assign_tf_variable',
- 'assign_ms_variable',
- 'assign_pd_variable',
- 'save_weights_to_hdf5',
- 'load_hdf5_to_weights_in_order',
- 'load_hdf5_to_weights',
- 'save_hdf5_graph',
- 'load_hdf5_graph',
- 'load_and_assign_ckpt',
- 'ckpt_to_npz_dict'
- ]
-
-
- def func2str(expr):
- b = cloudpickle.dumps(expr)
- s = base64.b64encode(b).decode()
- return s
-
-
- def str2func(s):
- b = base64.b64decode(s)
- expr = cloudpickle.loads(b)
- return expr
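-
- # A minimal sketch of the round trip above (assuming cloudpickle can serialize
- # the callable): the function is pickled, base64-encoded into plain text that
- # can be stored as an HDF5 attribute, and later restored to an equivalent callable.
- # >>> s = func2str(lambda x: x * 2)
- # >>> fn = str2func(s)
- # >>> fn(3)
- # 6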
-
-
- def save_hdf5_graph(network, filepath='model.hdf5', save_weights=False, customized_data=None):
- """Save the architecture of TL model into a hdf5 file. Support saving model weights.
-
- Parameters
- -----------
- network : TensorLayer Model.
- The network to save.
- filepath : str
- The name of model file.
- save_weights : bool
- Whether to save model weights.
- customized_data : dict
- The user customized meta data.
-
- Examples
- --------
- >>> # Save the architecture (with parameters)
- >>> tlx.files.save_hdf5_graph(network, filepath='model.hdf5', save_weights=True)
- >>> # Save the architecture (without parameters)
- >>> tlx.files.save_hdf5_graph(network, filepath='model.hdf5', save_weights=False)
- >>> # Load the architecture in another script (no parameters restore)
- >>> net = tlx.files.load_hdf5_graph(filepath='model.hdf5', load_weights=False)
- >>> # Load the architecture in another script (restore parameters)
- >>> net = tlx.files.load_hdf5_graph(filepath='model.hdf5', load_weights=True)
- """
- if network.outputs is None:
- raise RuntimeError("save_hdf5_graph not support dynamic mode yet")
-
- logging.info("[*] Saving TL model into {}, saving weights={}".format(filepath, save_weights))
-
- model_config = network.config # net2static_graph(network)
- model_config["version_info"]["save_date"] = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc
- ).isoformat()
- model_config_str = str(model_config)
- customized_data_str = str(customized_data)
-
- with h5py.File(filepath, 'w') as f:
- f.attrs["model_config"] = model_config_str.encode('utf8')
- f.attrs["customized_data"] = customized_data_str.encode('utf8')
- if save_weights:
- _save_weights_to_hdf5_group(f, network.all_layers)
- f.flush()
-
- logging.info("[*] Saved TL model into {}, saving weights={}".format(filepath, save_weights))
-
-
- def generate_func(args):
- for key in args:
- if isinstance(args[key], tuple) and args[key][0] == 'is_Func':
- fn = str2func(args[key][1])
- args[key] = fn
- # if key in ['act']:
- # # fn_dict = args[key]
- # # module_path = fn_dict['module_path']
- # # func_name = fn_dict['func_name']
- # # lib = importlib.import_module(module_path)
- # # fn = getattr(lib, func_name)
- # # args[key] = fn
- # fn = str2func(args[key])
- # args[key] = fn
- # elif key in ['fn']:
- # fn = str2func(args[key])
- # args[key] = fn
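-
- # A minimal sketch of how generate_func restores serialized callables in a layer
- # argument dict (the dict below is hypothetical): ('is_Func', <base64 string>)
- # tuples are replaced in place by the deserialized function.
- # >>> args = {'act': ('is_Func', func2str(lambda x: x + 1)), 'name': 'dense_1'}
- # >>> generate_func(args)
- # >>> args['act'](1)
- # 2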
-
- def load_hdf5_graph(filepath='model.hdf5', load_weights=False):
- """Restore TL model archtecture from a a pickle file. Support loading model weights.
-
- Parameters
- -----------
- filepath : str
- The name of model file.
- load_weights : bool
- Whether to load model weights.
-
- Returns
- --------
- network : TensorLayer Model.
-
- Examples
- --------
- - see ``tlx.files.save_hdf5_graph``
- """
- logging.info("[*] Loading TL model from {}, loading weights={}".format(filepath, load_weights))
-
- f = h5py.File(filepath, 'r')
-
- model_config_str = f.attrs["model_config"].decode('utf8')
- model_config = eval(model_config_str)
-
- # version_info_str = f.attrs["version_info"].decode('utf8')
- # version_info = eval(version_info_str)
- version_info = model_config["version_info"]
- backend_version = version_info["backend_version"]
- tensorlayerx_version = version_info["tensorlayerx_version"]
- if backend_version != tf.__version__:
- logging.warning(
- "Saved model uses tensorflow version {}, but now you are using tensorflow version {}".format(
- backend_version, tf.__version__
- )
- )
- if tensorlayerx_version != tlx.__version__:
- logging.warning(
- "Saved model uses tensorlayerx version {}, but now you are using tensorlayerx version {}".format(
- tensorlayerx_version, tlx.__version__
- )
- )
-
- M = static_graph2net(model_config)
- if load_weights:
- if 'layer_names' not in f.attrs.keys():
- raise RuntimeError("Saved model does not contain weights.")
- M.load_weights(filepath=filepath)
-
- f.close()
-
- logging.info("[*] Loaded TL model from {}, loading weights={}".format(filepath, load_weights))
-
- return M
-
-
- def load_mnist_dataset(shape=(-1, 784), path='data'):
- """Load the original mnist.
-
- Automatically download MNIST dataset and return the training, validation and test set with 50000, 10000 and 10000 digit images respectively.
-
- Parameters
- ----------
- shape : tuple
- The shape of digit images (the default is (-1, 784), alternatively (-1, 28, 28, 1)).
- path : str
- The path that the data is downloaded to.
-
- Returns
- -------
- X_train, y_train, X_val, y_val, X_test, y_test: tuple
- Return split training/validation/test sets respectively.
-
- Examples
- --------
- >>> X_train, y_train, X_val, y_val, X_test, y_test = tlx.files.load_mnist_dataset(shape=(-1,784), path='datasets')
- >>> X_train, y_train, X_val, y_val, X_test, y_test = tlx.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
- """
- return _load_mnist_dataset(shape, path, name='mnist', url='http://yann.lecun.com/exdb/mnist/')
-
-
- def load_fashion_mnist_dataset(shape=(-1, 784), path='data'):
- """Load the fashion mnist.
-
- Automatically download fashion-MNIST dataset and return the training, validation and test set with 50000, 10000 and 10000 fashion images respectively, `examples <http://marubon-ds.blogspot.co.uk/2017/09/fashion-mnist-exploring.html>`__.
-
- Parameters
- ----------
- shape : tuple
- The shape of digit images (the default is (-1, 784), alternatively (-1, 28, 28, 1)).
- path : str
- The path that the data is downloaded to.
-
- Returns
- -------
- X_train, y_train, X_val, y_val, X_test, y_test: tuple
- Return split training/validation/test sets respectively.
-
- Examples
- --------
- >>> X_train, y_train, X_val, y_val, X_test, y_test = tlx.files.load_fashion_mnist_dataset(shape=(-1,784), path='datasets')
- >>> X_train, y_train, X_val, y_val, X_test, y_test = tlx.files.load_fashion_mnist_dataset(shape=(-1, 28, 28, 1))
- """
- return _load_mnist_dataset(
- shape, path, name='fashion_mnist', url='http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
- )
-
-
- def _load_mnist_dataset(shape, path, name='mnist', url='http://yann.lecun.com/exdb/mnist/'):
- """A generic function to load mnist-like dataset.
-
- Parameters
- ----------
- shape : tuple
- The shape of digit images.
- path : str
- The path that the data is downloaded to.
- name : str
- The dataset name you want to use (the default is 'mnist').
- url : str
- The URL of the dataset (the default is 'http://yann.lecun.com/exdb/mnist/').
- """
- path = os.path.join(path, name)
-
- # Define functions for loading mnist-like data's images and labels.
- # For convenience, they also download the requested files if needed.
- def load_mnist_images(path, filename):
- filepath = maybe_download_and_extract(filename, path, url)
-
- logging.info(filepath)
- # Read the inputs in Yann LeCun's binary format.
- with gzip.open(filepath, 'rb') as f:
- data = np.frombuffer(f.read(), np.uint8, offset=16)
- # The inputs are vectors now, we reshape them to monochrome 2D images,
- # following the shape convention: (examples, channels, rows, columns)
- data = data.reshape(shape)
- # The inputs come as bytes, we convert them to float32 in range [0,1].
- # (Actually to range [0, 255/256], for compatibility to the version
- # provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.)
- return data / np.float32(256)
-
- def load_mnist_labels(path, filename):
- filepath = maybe_download_and_extract(filename, path, url)
- # Read the labels in Yann LeCun's binary format.
- with gzip.open(filepath, 'rb') as f:
- data = np.frombuffer(f.read(), np.uint8, offset=8)
- # The labels are vectors of integers now, that's exactly what we want.
- return data
-
- # Download and read the training and test set images and labels.
- logging.info("Load or Download {0} > {1}".format(name.upper(), path))
- X_train = load_mnist_images(path, 'train-images-idx3-ubyte.gz')
- y_train = load_mnist_labels(path, 'train-labels-idx1-ubyte.gz')
- X_test = load_mnist_images(path, 't10k-images-idx3-ubyte.gz')
- y_test = load_mnist_labels(path, 't10k-labels-idx1-ubyte.gz')
-
- # We reserve the last 10000 training examples for validation.
- X_train, X_val = X_train[:-10000], X_train[-10000:]
- y_train, y_val = y_train[:-10000], y_train[-10000:]
-
- # We just return all the arrays in order, as expected in main().
- # (It doesn't matter how we do this as long as we can read them again.)
- X_train = np.asarray(X_train, dtype=np.float32)
- y_train = np.asarray(y_train, dtype=np.int32)
- X_val = np.asarray(X_val, dtype=np.float32)
- y_val = np.asarray(y_val, dtype=np.int32)
- X_test = np.asarray(X_test, dtype=np.float32)
- y_test = np.asarray(y_test, dtype=np.int32)
- return X_train, y_train, X_val, y_val, X_test, y_test
-
-
- def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='data', plotable=False):
- """Load CIFAR-10 dataset.
-
- It consists of 60000 32x32 colour images in 10 classes, with
- 6000 images per class. There are 50000 training images and 10000 test images.
-
- The dataset is divided into five training batches and one test batch, each with
- 10000 images. The test batch contains exactly 1000 randomly-selected images from
- each class. The training batches contain the remaining images in random order,
- but some training batches may contain more images from one class than another.
- Between them, the training batches contain exactly 5000 images from each class.
-
- Parameters
- ----------
- shape : tuple
- The shape of digit images e.g. (-1, 3, 32, 32) and (-1, 32, 32, 3).
- path : str
- The path that the data is downloaded to, default is ``data/cifar10/``.
- plotable : boolean
- Whether to plot some image examples, default is False.
-
- Examples
- --------
- >>> X_train, y_train, X_test, y_test = tlx.files.load_cifar10_dataset(shape=(-1, 32, 32, 3))
-
- References
- ----------
- - `CIFAR website <https://www.cs.toronto.edu/~kriz/cifar.html>`__
- - `Data download link <https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz>`__
- - `<https://teratail.com/questions/28932>`__
-
- """
- path = os.path.join(path, 'cifar10')
- logging.info("Load or Download cifar10 > {}".format(path))
-
- # Helper function to unpickle the data
- def unpickle(file):
- fp = open(file, 'rb')
- if sys.version_info.major == 2:
- data = pickle.load(fp)
- elif sys.version_info.major == 3:
- data = pickle.load(fp, encoding='latin-1')
- fp.close()
- return data
-
- filename = 'cifar-10-python.tar.gz'
- url = 'https://www.cs.toronto.edu/~kriz/'
- # Download and uncompress file
- maybe_download_and_extract(filename, path, url, extract=True)
-
- # Unpickle file and fill in data
- X_train = None
- y_train = []
- for i in range(1, 6):
- data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "data_batch_{}".format(i)))
- if i == 1:
- X_train = data_dic['data']
- else:
- X_train = np.vstack((X_train, data_dic['data']))
- y_train += data_dic['labels']
-
- test_data_dic = unpickle(os.path.join(path, 'cifar-10-batches-py/', "test_batch"))
- X_test = test_data_dic['data']
- y_test = np.array(test_data_dic['labels'])
-
- if shape == (-1, 3, 32, 32):
- X_test = X_test.reshape(shape)
- X_train = X_train.reshape(shape)
- elif shape == (-1, 32, 32, 3):
- X_test = X_test.reshape(shape, order='F')
- X_train = X_train.reshape(shape, order='F')
- X_test = np.transpose(X_test, (0, 2, 1, 3))
- X_train = np.transpose(X_train, (0, 2, 1, 3))
- else:
- X_test = X_test.reshape(shape)
- X_train = X_train.reshape(shape)
-
- y_train = np.array(y_train)
-
- if plotable:
-
- if sys.platform.startswith('darwin'):
- import matplotlib
- matplotlib.use('TkAgg')
- import matplotlib.pyplot as plt
-
- logging.info('\nCIFAR-10')
- fig = plt.figure(1)
-
- logging.info('Shape of a training image: X_train[0] %s' % X_train[0].shape)
-
- plt.ion() # interactive mode
- count = 1
- for _ in range(10): # each row
- for _ in range(10): # each column
- _ = fig.add_subplot(10, 10, count)
- if shape == (-1, 3, 32, 32):
- # plt.imshow(X_train[count-1], interpolation='nearest')
- plt.imshow(np.transpose(X_train[count - 1], (1, 2, 0)), interpolation='nearest')
- # plt.imshow(np.transpose(X_train[count-1], (2, 1, 0)), interpolation='nearest')
- elif shape == (-1, 32, 32, 3):
- plt.imshow(X_train[count - 1], interpolation='nearest')
- # plt.imshow(np.transpose(X_train[count-1], (1, 0, 2)), interpolation='nearest')
- else:
- raise Exception("Do not support the given 'shape' to plot the image examples")
- plt.gca().xaxis.set_major_locator(plt.NullLocator()) # hide the ticks
- plt.gca().yaxis.set_major_locator(plt.NullLocator())
- count = count + 1
- plt.draw() # interactive mode
- plt.pause(3) # interactive mode
-
- logging.info("X_train: %s" % X_train.shape)
- logging.info("y_train: %s" % y_train.shape)
- logging.info("X_test: %s" % X_test.shape)
- logging.info("y_test: %s" % y_test.shape)
-
- X_train = np.asarray(X_train, dtype=np.float32)
- X_test = np.asarray(X_test, dtype=np.float32)
- y_train = np.asarray(y_train, dtype=np.int32)
- y_test = np.asarray(y_test, dtype=np.int32)
-
- return X_train, y_train, X_test, y_test
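-
- # A minimal sketch (hypothetical one-image batch) of why the channels-last branch
- # above uses order='F' plus a (0, 2, 1, 3) transpose: each raw CIFAR-10 row stores
- # 1024 red, then 1024 green, then 1024 blue values, so channels-first is a plain
- # reshape and channels-last can be derived from it.
- # >>> raw = np.arange(3072).reshape(1, 3072)
- # >>> chw = raw.reshape(-1, 3, 32, 32)
- # >>> hwc = np.transpose(raw.reshape(-1, 32, 32, 3, order='F'), (0, 2, 1, 3))
- # >>> np.array_equal(hwc, np.transpose(chw, (0, 2, 3, 1)))
- # True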
-
-
- def load_cropped_svhn(path='data', include_extra=True):
- """Load Cropped SVHN.
-
- The Cropped Street View House Numbers (SVHN) Dataset contains 32x32x3 RGB images.
- Digit '1' has label 1, '9' has label 9 and '0' has label 0 (the original dataset uses 10 to represent '0'), see `ufldl website <http://ufldl.stanford.edu/housenumbers/>`__.
-
- Parameters
- ----------
- path : str
- The path that the data is downloaded to.
- include_extra : boolean
- If True (default), add extra images to the training set.
-
- Returns
- -------
- X_train, y_train, X_test, y_test: tuple
- Return split training/test sets respectively.
-
- Examples
- ---------
- >>> X_train, y_train, X_test, y_test = tlx.files.load_cropped_svhn(include_extra=False)
- >>> tlx.vis.save_images(X_train[0:100], [10, 10], 'svhn.png')
-
- """
- start_time = time.time()
-
- path = os.path.join(path, 'cropped_svhn')
- logging.info("Load or Download Cropped SVHN > {} | include extra images: {}".format(path, include_extra))
- url = "http://ufldl.stanford.edu/housenumbers/"
-
- np_file = os.path.join(path, "train_32x32.npz")
- if file_exists(np_file) is False:
- filename = "train_32x32.mat"
- filepath = maybe_download_and_extract(filename, path, url)
- mat = sio.loadmat(filepath)
- X_train = mat['X'] / 255.0 # to [0, 1]
- X_train = np.transpose(X_train, (3, 0, 1, 2))
- y_train = np.squeeze(mat['y'], axis=1)
- y_train[y_train == 10] = 0 # replace 10 to 0
- np.savez(np_file, X=X_train, y=y_train)
- del_file(filepath)
- else:
- v = np.load(np_file, allow_pickle=True)
- X_train = v['X']
- y_train = v['y']
- logging.info(" n_train: {}".format(len(y_train)))
-
- np_file = os.path.join(path, "test_32x32.npz")
- if file_exists(np_file) is False:
- filename = "test_32x32.mat"
- filepath = maybe_download_and_extract(filename, path, url)
- mat = sio.loadmat(filepath)
- X_test = mat['X'] / 255.0
- X_test = np.transpose(X_test, (3, 0, 1, 2))
- y_test = np.squeeze(mat['y'], axis=1)
- y_test[y_test == 10] = 0
- np.savez(np_file, X=X_test, y=y_test)
- del_file(filepath)
- else:
- v = np.load(np_file, allow_pickle=True)
- X_test = v['X']
- y_test = v['y']
- logging.info(" n_test: {}".format(len(y_test)))
-
- if include_extra:
- logging.info(" getting extra 531131 images, please wait ...")
- np_file = os.path.join(path, "extra_32x32.npz")
- if file_exists(np_file) is False:
- logging.info(" the first time to load extra images will take long time to convert the file format ...")
- filename = "extra_32x32.mat"
- filepath = maybe_download_and_extract(filename, path, url)
- mat = sio.loadmat(filepath)
- X_extra = mat['X'] / 255.0
- X_extra = np.transpose(X_extra, (3, 0, 1, 2))
- y_extra = np.squeeze(mat['y'], axis=1)
- y_extra[y_extra == 10] = 0
- np.savez(np_file, X=X_extra, y=y_extra)
- del_file(filepath)
- else:
- v = np.load(np_file, allow_pickle=True)
- X_extra = v['X']
- y_extra = v['y']
- logging.info(" adding n_extra {} to n_train {}".format(len(y_extra), len(y_train)))
- t = time.time()
- X_train = np.concatenate((X_train, X_extra), 0)
- y_train = np.concatenate((y_train, y_extra), 0)
- # X_train = np.append(X_train, X_extra, axis=0)
- # y_train = np.append(y_train, y_extra, axis=0)
- logging.info(" added n_extra {} to n_train {} took {}s".format(len(y_extra), len(y_train), time.time() - t))
- else:
- logging.info(" no extra images are included")
- logging.info(" image size: %s n_train: %d n_test: %d" % (str(X_train.shape[1:4]), len(y_train), len(y_test)))
- logging.info(" took: {}s".format(int(time.time() - start_time)))
- return X_train, y_train, X_test, y_test
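-
- # A minimal sketch of the transpose used above (hypothetical 5-image batch): the
- # SVHN .mat files store images as (32, 32, 3, N), and moving axis 3 to the front
- # yields the (N, 32, 32, 3) layout returned by this loader.
- # >>> x = np.zeros((32, 32, 3, 5))
- # >>> np.transpose(x, (3, 0, 1, 2)).shape
- # (5, 32, 32, 3)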
-
-
- # def load_ptb_dataset(path='data'):
- # """Load Penn TreeBank (PTB) dataset.
- #
- # It is used in many LANGUAGE MODELING papers,
- # including "Empirical Evaluation and Combination of Advanced Language
- # Modeling Techniques", "Recurrent Neural Network Regularization".
- # It consists of 929k training words, 73k validation words, and 82k test
- # words. It has 10k words in its vocabulary.
- #
- # Parameters
- # ----------
- # path : str
- # The path that the data is downloaded to, defaults is ``data/ptb/``.
- #
- # Returns
- # --------
- # train_data, valid_data, test_data : list of int
- # The training, validating and testing data in integer format.
- # vocab_size : int
- # The vocabulary size.
- #
- # Examples
- # --------
- # >>> train_data, valid_data, test_data, vocab_size = tlx.files.load_ptb_dataset()
- #
- # References
- # ---------------
- # - ``tensorflow.model.rnn.ptb import reader``
- # - `Manual download <http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz>`__
- #
- # Notes
- # ------
- # - If you want to get the raw data, see the source code.
- #
- # """
- # path = os.path.join(path, 'ptb')
- # logging.info("Load or Download Penn TreeBank (PTB) dataset > {}".format(path))
- #
- # # Maybe dowload and uncompress tar, or load exsisting files
- # filename = 'simple-examples.tgz'
- # url = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/'
- # maybe_download_and_extract(filename, path, url, extract=True)
- #
- # data_path = os.path.join(path, 'simple-examples', 'data')
- # train_path = os.path.join(data_path, "ptb.train.txt")
- # valid_path = os.path.join(data_path, "ptb.valid.txt")
- # test_path = os.path.join(data_path, "ptb.test.txt")
- #
- # word_to_id = nlp.build_vocab(nlp.read_words(train_path))
- #
- # train_data = nlp.words_to_word_ids(nlp.read_words(train_path), word_to_id)
- # valid_data = nlp.words_to_word_ids(nlp.read_words(valid_path), word_to_id)
- # test_data = nlp.words_to_word_ids(nlp.read_words(test_path), word_to_id)
- # vocab_size = len(word_to_id)
- #
- # # logging.info(nlp.read_words(train_path)) # ... 'according', 'to', 'mr.', '<unk>', '<eos>']
- # # logging.info(train_data) # ... 214, 5, 23, 1, 2]
- # # logging.info(word_to_id) # ... 'beyond': 1295, 'anti-nuclear': 9599, 'trouble': 1520, '<eos>': 2 ... }
- # # logging.info(vocabulary) # 10000
- # # exit()
- # return train_data, valid_data, test_data, vocab_size
-
-
- def load_matt_mahoney_text8_dataset(path='data'):
- """Load Matt Mahoney's dataset.
-
- Download a text file from Matt Mahoney's website
- if not present, and make sure it's the right size.
- Extract the first file enclosed in a zip file as a list of words.
- This dataset can be used for Word Embedding.
-
- Parameters
- ----------
- path : str
- The path that the data is downloaded to, default is ``data/mm_test8/``.
-
- Returns
- --------
- list of str
- The raw text data e.g. [.... 'their', 'families', 'who', 'were', 'expelled', 'from', 'jerusalem', ...]
-
- Examples
- --------
- >>> words = tlx.files.load_matt_mahoney_text8_dataset()
- >>> print('Data size', len(words))
-
- """
- path = os.path.join(path, 'mm_test8')
- logging.info("Load or Download matt_mahoney_text8 Dataset> {}".format(path))
-
- filename = 'text8.zip'
- url = 'http://mattmahoney.net/dc/'
- maybe_download_and_extract(filename, path, url, expected_bytes=31344016)
-
- with zipfile.ZipFile(os.path.join(path, filename)) as f:
- word_list = f.read(f.namelist()[0]).split()
- for idx, _ in enumerate(word_list):
- word_list[idx] = word_list[idx].decode()
- return word_list
-
-
- def load_imdb_dataset(
- path='data', nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113, start_char=1, oov_char=2,
- index_from=3
- ):
- """Load IMDB dataset.
-
- Parameters
- ----------
- path : str
- The path that the data is downloaded to, default is ``data/imdb/``.
- nb_words : int
- Number of words to get.
- skip_top : int
- Top most frequent words to ignore (they will appear as oov_char value in the sequence data).
- maxlen : int
- Maximum sequence length. Any longer sequence will be truncated.
- seed : int
- Seed for reproducible data shuffling.
- start_char : int
- The start of a sequence will be marked with this character. Set to 1 because 0 is usually the padding character.
- oov_char : int
- Words that were cut out because of the num_words or skip_top limit will be replaced with this character.
- index_from : int
- Index actual words with this index and higher.
-
- Examples
- --------
- >>> X_train, y_train, X_test, y_test = tlx.files.load_imdb_dataset(
- ... nb_words=20000, test_split=0.2)
- >>> print('X_train.shape', X_train.shape)
- (20000,) [[1, 62, 74, ... 1033, 507, 27],[1, 60, 33, ... 13, 1053, 7]..]
- >>> print('y_train.shape', y_train.shape)
- (20000,) [1 0 0 ..., 1 0 1]
-
- References
- -----------
- - `Modified from keras. <https://github.com/fchollet/keras/blob/master/keras/datasets/imdb.py>`__
-
- """
- path = os.path.join(path, 'imdb')
-
- filename = "imdb.pkl"
- url = 'https://s3.amazonaws.com/text-datasets/'
- maybe_download_and_extract(filename, path, url)
-
- if filename.endswith(".gz"):
- f = gzip.open(os.path.join(path, filename), 'rb')
- else:
- f = open(os.path.join(path, filename), 'rb')
-
- X, labels = cPickle.load(f)
- f.close()
-
- np.random.seed(seed)
- np.random.shuffle(X)
- np.random.seed(seed)
- np.random.shuffle(labels)
-
- if start_char is not None:
- X = [[start_char] + [w + index_from for w in x] for x in X]
- elif index_from:
- X = [[w + index_from for w in x] for x in X]
-
- if maxlen:
- new_X = []
- new_labels = []
- for x, y in zip(X, labels):
- if len(x) < maxlen:
- new_X.append(x)
- new_labels.append(y)
- X = new_X
- labels = new_labels
- if not X:
- raise Exception(
- 'After filtering for sequences shorter than maxlen=' + str(maxlen) + ', no sequence was kept. '
- 'Increase maxlen.'
- )
- if not nb_words:
- nb_words = max([max(x) for x in X])
-
- # by convention, use 2 as OOV word
- # reserve 'index_from' (=3 by default) characters: 0 (padding), 1 (start), 2 (OOV)
- if oov_char is not None:
- X = [[oov_char if (w >= nb_words or w < skip_top) else w for w in x] for x in X]
- else:
- nX = []
- for x in X:
- nx = []
- for w in x:
- if (w >= nb_words or w < skip_top):
- nx.append(w)
- nX.append(nx)
- X = nX
-
- X_train = np.array(X[:int(len(X) * (1 - test_split))])
- y_train = np.array(labels[:int(len(X) * (1 - test_split))])
-
- X_test = np.array(X[int(len(X) * (1 - test_split)):])
- y_test = np.array(labels[int(len(X) * (1 - test_split)):])
-
- return X_train, y_train, X_test, y_test
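-
- # A minimal sketch (hypothetical toy review) of the index mapping applied above
- # with the defaults start_char=1, oov_char=2, index_from=3 and nb_words=2000:
- # >>> # raw review:                      [5, 20, 7000]
- # >>> # after start_char / index_from:   [1, 8, 23, 7003]
- # >>> # after OOV replacement (>= 2000): [1, 8, 23, 2]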
-
-
- def load_nietzsche_dataset(path='data'):
- """Load Nietzsche dataset.
-
- Parameters
- ----------
- path : str
- The path that the data is downloaded to, default is ``data/nietzsche/``.
-
- Returns
- --------
- str
- The content.
-
- Examples
- --------
- >>> see tutorial_generate_text.py
- >>> words = tlx.files.load_nietzsche_dataset()
- >>> words = basic_clean_str(words)
- >>> words = words.split()
-
- """
- logging.info("Load or Download nietzsche dataset > {}".format(path))
- path = os.path.join(path, 'nietzsche')
-
- filename = "nietzsche.txt"
- url = 'https://s3.amazonaws.com/text-datasets/'
- filepath = maybe_download_and_extract(filename, path, url)
-
- with open(filepath, "r") as f:
- words = f.read()
- return words
-
-
- def load_wmt_en_fr_dataset(path='data'):
- """Load WMT'15 English-to-French translation dataset.
-
- It will download the data from the WMT'15 Website (10^9-French-English corpus), and the 2013 news test from the same site as development set.
- Returns the directories of training data and test data.
-
- Parameters
- ----------
- path : str
- The path that the data is downloaded to, default is ``data/wmt_en_fr/``.
-
- References
- ----------
- - Code modified from /tensorflow/model/rnn/translation/data_utils.py
-
- Notes
- -----
- Usually, it will take a long time to download this dataset.
-
- """
- path = os.path.join(path, 'wmt_en_fr')
- # URLs for WMT data.
- _WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/"
- _WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/"
-
- def gunzip_file(gz_path, new_path):
- """Unzips from gz_path into new_path."""
- logging.info("Unpacking %s to %s" % (gz_path, new_path))
- with gzip.open(gz_path, "rb") as gz_file:
- with open(new_path, "wb") as new_file:
- for line in gz_file:
- new_file.write(line)
-
- def get_wmt_enfr_train_set(path):
- """Download the WMT en-fr training corpus to directory unless it's there."""
- filename = "training-giga-fren.tar"
- maybe_download_and_extract(filename, path, _WMT_ENFR_TRAIN_URL, extract=True)
- train_path = os.path.join(path, "giga-fren.release2.fixed")
- gunzip_file(train_path + ".fr.gz", train_path + ".fr")
- gunzip_file(train_path + ".en.gz", train_path + ".en")
- return train_path
-
- def get_wmt_enfr_dev_set(path):
- """Download the WMT en-fr training corpus to directory unless it's there."""
- filename = "dev-v2.tgz"
- dev_file = maybe_download_and_extract(filename, path, _WMT_ENFR_DEV_URL, extract=False)
- dev_name = "newstest2013"
- dev_path = os.path.join(path, "newstest2013")
- if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")):
- logging.info("Extracting tgz file %s" % dev_file)
- with tarfile.open(dev_file, "r:gz") as dev_tar:
- fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr")
- en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en")
- fr_dev_file.name = dev_name + ".fr" # Extract without "dev/" prefix.
- en_dev_file.name = dev_name + ".en"
- dev_tar.extract(fr_dev_file, path)
- dev_tar.extract(en_dev_file, path)
- return dev_path
-
- logging.info("Load or Download WMT English-to-French translation > {}".format(path))
-
- train_path = get_wmt_enfr_train_set(path)
- dev_path = get_wmt_enfr_dev_set(path)
-
- return train_path, dev_path
-
-
- def load_flickr25k_dataset(tag='sky', path="data", n_threads=50, printable=False):
- """Load Flickr25K dataset.
-
- Returns a list of images by a given tag from the Flickr25k dataset.
- It will download Flickr25k from `the official website <http://press.liacs.nl/mirflickr/mirdownload.html>`__
- the first time you use it.
-
- Parameters
- ------------
- tag : str or None
- What images to return.
- - If you want to get images with tag, use string like 'dog', 'red', see `Flickr Search <https://www.flickr.com/search/>`__.
- - If you want to get all images, set to ``None``.
-
- path : str
- The path that the data is downloaded to, default is ``data/flickr25k/``.
- n_threads : int
- The number of threads used to read images.
- printable : boolean
- Whether to print information when reading images, default is ``False``.
-
- Examples
- -----------
- Get images with tag of sky
-
- >>> images = tlx.files.load_flickr25k_dataset(tag='sky')
-
- Get all images
-
- >>> images = tlx.files.load_flickr25k_dataset(tag=None, n_threads=100, printable=True)
-
- """
- path = os.path.join(path, 'flickr25k')
-
- filename = 'mirflickr25k.zip'
- url = 'http://press.liacs.nl/mirflickr/mirflickr25k/'
-
- # download dataset
- if folder_exists(os.path.join(path, "mirflickr")) is False:
- logging.info("[*] Flickr25k is nonexistent in {}".format(path))
- maybe_download_and_extract(filename, path, url, extract=True)
- del_file(os.path.join(path, filename))
-
- # return images by the given tag.
- # 1. image path list
- folder_imgs = os.path.join(path, "mirflickr")
- path_imgs = load_file_list(path=folder_imgs, regx='\\.jpg', printable=False)
- path_imgs.sort(key=natural_keys)
-
- # 2. tag path list
- folder_tags = os.path.join(path, "mirflickr", "meta", "tags")
- path_tags = load_file_list(path=folder_tags, regx='\\.txt', printable=False)
- path_tags.sort(key=natural_keys)
-
- # 3. select images
- if tag is None:
- logging.info("[Flickr25k] reading all images")
- else:
- logging.info("[Flickr25k] reading images with tag: {}".format(tag))
- images_list = []
- for idx, _v in enumerate(path_tags):
- tags = read_file(os.path.join(folder_tags, path_tags[idx])).split('\n')
- # logging.info(idx+1, tags)
- if tag is None or tag in tags:
- images_list.append(path_imgs[idx])
-
- images = visualize.read_images(images_list, folder_imgs, n_threads=n_threads, printable=printable)
- return images
-
-
- def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printable=False):
- """Load Flick1M dataset.
-
- Returns a list of images by a given tag from Flickr1M dataset,
- it will download Flickr1M from `the official website <http://press.liacs.nl/mirflickr/mirdownload.html>`__
- at the first time you use it.
-
- Parameters
- ------------
- tag : str or None
- What images to return.
- - If you want to get images with tag, use string like 'dog', 'red', see `Flickr Search <https://www.flickr.com/search/>`__.
- - If you want to get all images, set to ``None``.
-
- size : int
- Integer between 1 and 10. 1 means 100k images ... 5 means 500k images, 10 means all 1 million images. Default is 10.
- path : str
- The path that the data is downloaded to, default is ``data/flickr1M/``.
- n_threads : int
- The number of threads used to read images.
- printable : boolean
- Whether to print information when reading images, default is ``False``.
-
- Examples
- ----------
- Use 200k images
-
- >>> images = tlx.files.load_flickr1M_dataset(tag='zebra', size=2)
-
- Use 1 Million images
-
- >>> images = tlx.files.load_flickr1M_dataset(tag='zebra')
-
- """
- path = os.path.join(path, 'flickr1M')
- logging.info("[Flickr1M] using {}% of images = {}".format(size * 10, size * 100000))
- images_zip = [
- 'images0.zip', 'images1.zip', 'images2.zip', 'images3.zip', 'images4.zip', 'images5.zip', 'images6.zip',
- 'images7.zip', 'images8.zip', 'images9.zip'
- ]
- tag_zip = 'tags.zip'
- url = 'http://press.liacs.nl/mirflickr/mirflickr1m/'
-
- # download dataset
- for image_zip in images_zip[0:size]:
- image_folder = image_zip.split(".")[0]
- # logging.info(path+"/"+image_folder)
- if folder_exists(os.path.join(path, image_folder)) is False:
- # logging.info(image_zip)
- logging.info("[Flickr1M] {} is missing in {}".format(image_folder, path))
- maybe_download_and_extract(image_zip, path, url, extract=True)
- del_file(os.path.join(path, image_zip))
- # os.system("mv {} {}".format(os.path.join(path, 'images'), os.path.join(path, image_folder)))
- shutil.move(os.path.join(path, 'images'), os.path.join(path, image_folder))
- else:
- logging.info("[Flickr1M] {} exists in {}".format(image_folder, path))
-
- # download tag
- if folder_exists(os.path.join(path, "tags")) is False:
- logging.info("[Flickr1M] tag files is nonexistent in {}".format(path))
- maybe_download_and_extract(tag_zip, path, url, extract=True)
- del_file(os.path.join(path, tag_zip))
- else:
- logging.info("[Flickr1M] tags exists in {}".format(path))
-
- # 1. image path list
- images_list = []
- images_folder_list = []
- for i in range(0, size):
- images_folder_list += load_folder_list(path=os.path.join(path, 'images%d' % i))
- images_folder_list.sort(key=lambda s: int(s.split('/')[-1])) # folder/images/ddd
-
- for folder in images_folder_list[0:size * 10]:
- tmp = load_file_list(path=folder, regx='\\.jpg', printable=False)
- tmp.sort(key=lambda s: int(s.split('.')[-2])) # ddd.jpg
- images_list.extend([os.path.join(folder, x) for x in tmp])
-
- # 2. tag path list
- tag_list = []
- tag_folder_list = load_folder_list(os.path.join(path, "tags"))
-
- # tag_folder_list.sort(key=lambda s: int(s.split("/")[-1])) # folder/images/ddd
- tag_folder_list.sort(key=lambda s: int(os.path.basename(s)))
-
- for folder in tag_folder_list[0:size * 10]:
- tmp = load_file_list(path=folder, regx='\\.txt', printable=False)
- tmp.sort(key=lambda s: int(s.split('.')[-2])) # ddd.txt
- tmp = [os.path.join(folder, s) for s in tmp]
- tag_list += tmp
-
- # 3. select images
- logging.info("[Flickr1M] searching tag: {}".format(tag))
- select_images_list = []
- for idx, _val in enumerate(tag_list):
- tags = read_file(tag_list[idx]).split('\n')
- if tag in tags:
- select_images_list.append(images_list[idx])
-
- logging.info("[Flickr1M] reading images with tag: {}".format(tag))
- images = visualize.read_images(select_images_list, '', n_threads=n_threads, printable=printable)
- return images
-
-
- def load_cyclegan_dataset(filename='summer2winter_yosemite', path='data'):
- """Load images from CycleGAN's database, see `this link <https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/>`__.
-
- Parameters
- ------------
- filename : str
- The dataset you want, see `this link <https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/>`__.
- path : str
- The path that the data is downloaded to, default is ``data/cyclegan/``.
-
- Examples
- ---------
- >>> im_train_A, im_train_B, im_test_A, im_test_B = load_cyclegan_dataset(filename='summer2winter_yosemite')
-
- """
- path = os.path.join(path, 'cyclegan')
- url = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/'
-
- if folder_exists(os.path.join(path, filename)) is False:
- logging.info("[*] {} is nonexistent in {}".format(filename, path))
- maybe_download_and_extract(filename + '.zip', path, url, extract=True)
- del_file(os.path.join(path, filename + '.zip'))
-
- def load_image_from_folder(path):
- path_imgs = load_file_list(path=path, regx='\\.jpg', printable=False)
- return visualize.read_images(path_imgs, path=path, n_threads=10, printable=False)
-
- im_train_A = load_image_from_folder(os.path.join(path, filename, "trainA"))
- im_train_B = load_image_from_folder(os.path.join(path, filename, "trainB"))
- im_test_A = load_image_from_folder(os.path.join(path, filename, "testA"))
- im_test_B = load_image_from_folder(os.path.join(path, filename, "testB"))
-
- def if_2d_to_3d(images): # [h, w] --> [h, w, 3]
- for i, _v in enumerate(images):
- if len(images[i].shape) == 2:
- images[i] = images[i][:, :, np.newaxis]
- images[i] = np.tile(images[i], (1, 1, 3))
- return images
-
- im_train_A = if_2d_to_3d(im_train_A)
- im_train_B = if_2d_to_3d(im_train_B)
- im_test_A = if_2d_to_3d(im_test_A)
- im_test_B = if_2d_to_3d(im_test_B)
-
- return im_train_A, im_train_B, im_test_A, im_test_B
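-
- # A minimal sketch (hypothetical grayscale image) of the if_2d_to_3d helper above:
- # a [h, w] array gains a channel axis and is tiled to [h, w, 3].
- # >>> g = np.zeros((256, 256))
- # >>> rgb = np.tile(g[:, :, np.newaxis], (1, 1, 3))
- # >>> rgb.shape
- # (256, 256, 3)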
-
-
- def download_file_from_google_drive(ID, destination):
- """Download file from Google Drive.
-
- See ``tlx.files.load_celebA_dataset`` for example.
-
- Parameters
- --------------
- ID : str
- The Google Drive file ID.
- destination : str
- The destination path to save the file.
-
- """
- try:
- from tqdm import tqdm
- except ImportError as e:
- print(e)
- raise ImportError("Module tqdm not found. Please install tqdm via pip or other package managers.")
-
- try:
- import requests
- except ImportError as e:
- print(e)
- raise ImportError("Module requests not found. Please install requests via pip or other package managers.")
-
- def save_response_content(response, destination, chunk_size=32 * 1024):
-
- total_size = int(response.headers.get('content-length', 0))
- with open(destination, "wb") as f:
- for chunk in tqdm(response.iter_content(chunk_size), total=total_size, unit='B', unit_scale=True,
- desc=destination):
- if chunk: # filter out keep-alive new chunks
- f.write(chunk)
-
- def get_confirm_token(response):
- for key, value in response.cookies.items():
- if key.startswith('download_warning'):
- return value
- return None
-
- URL = "https://docs.google.com/uc?export=download"
- session = requests.Session()
-
- response = session.get(URL, params={'id': ID}, stream=True)
- token = get_confirm_token(response)
-
- if token:
- params = {'id': ID, 'confirm': token}
- response = session.get(URL, params=params, stream=True)
- save_response_content(response, destination)
-
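- # A minimal sketch of the helper above (the file ID shown is the CelebA archive
- # used by load_celebA_dataset below):
- # >>> download_file_from_google_drive('0B7EVK8r0v71pZjFTYXZWM3FlRnM', 'data/img_align_celeba.zip')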
-
- def load_celebA_dataset(path='data'):
- """Load CelebA dataset
-
- Return a list of image path.
-
- Parameters
- -----------
- path : str
- The path that the data is downloaded to, defaults is ``data/celebA/``.
-
- """
- data_dir = 'celebA'
- filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
- save_path = os.path.join(path, filename)
- image_path = os.path.join(path, data_dir)
- if os.path.exists(image_path):
- logging.info('[*] {} already exists'.format(save_path))
- else:
- exists_or_mkdir(path)
- download_file_from_google_drive(drive_id, save_path)
- zip_dir = ''
- with zipfile.ZipFile(save_path) as zf:
- zip_dir = zf.namelist()[0]
- zf.extractall(path)
- os.remove(save_path)
- os.rename(os.path.join(path, zip_dir), image_path)
-
- data_files = load_file_list(path=image_path, regx='\\.jpg', printable=False)
- for i, _v in enumerate(data_files):
- data_files[i] = os.path.join(image_path, data_files[i])
- return data_files
-
-
- # def load_voc_dataset(path='data', dataset='2012', contain_classes_in_person=False):
- # """Pascal VOC 2007/2012 Dataset.
- #
- # It has 20 objects:
- # aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow, diningtable, dog, horse, motorbike, person, pottedplant, sheep, sofa, train, tvmonitor
- # and additional 3 classes : head, hand, foot for person.
- #
- # Parameters
- # -----------
- # path : str
- # The path that the data is downloaded to, defaults is ``data/VOC``.
- # dataset : str
- # The VOC dataset version, `2012`, `2007`, `2007test` or `2012test`. We usually train model on `2007+2012` and test it on `2007test`.
- # contain_classes_in_person : boolean
- # Whether include head, hand and foot annotation, default is False.
- #
- # Returns
- # ---------
- # imgs_file_list : list of str
- # Full paths of all images.
- # imgs_semseg_file_list : list of str
- # Full paths of all maps for semantic segmentation. Note that not all images have this map!
- # imgs_insseg_file_list : list of str
- # Full paths of all maps for instance segmentation. Note that not all images have this map!
- # imgs_ann_file_list : list of str
- # Full paths of all annotations for bounding box and object class, all images have this annotations.
- # classes : list of str
- # Classes in order.
- # classes_in_person : list of str
- # Classes in person.
- # classes_dict : dictionary
- # Class label to integer.
- # n_objs_list : list of int
- # Number of objects in all images in ``imgs_file_list`` in order.
- # objs_info_list : list of str
- # Darknet format for the annotation of all images in ``imgs_file_list`` in order. ``[class_id x_centre y_centre width height]`` in ratio format.
- # objs_info_dicts : dictionary
- # The annotation of all images in ``imgs_file_list``, ``{imgs_file_list : dictionary for annotation}``,
- # format from `TensorFlow/Models/object-detection <https://github.com/tensorflow/models/blob/master/object_detection/create_pascal_tf_record.py>`__.
- #
- # Examples
- # ----------
- # >>> imgs_file_list, imgs_semseg_file_list, imgs_insseg_file_list, imgs_ann_file_list,
- # >>> classes, classes_in_person, classes_dict,
- # >>> n_objs_list, objs_info_list, objs_info_dicts = tlx.files.load_voc_dataset(dataset="2012", contain_classes_in_person=False)
- # >>> idx = 26
- # >>> print(classes)
- # ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
- # >>> print(classes_dict)
- # {'sheep': 16, 'horse': 12, 'bicycle': 1, 'bottle': 4, 'cow': 9, 'sofa': 17, 'car': 6, 'dog': 11, 'cat': 7, 'person': 14, 'train': 18, 'diningtable': 10, 'aeroplane': 0, 'bus': 5, 'pottedplant': 15, 'tvmonitor': 19, 'chair': 8, 'bird': 2, 'boat': 3, 'motorbike': 13}
- # >>> print(imgs_file_list[idx])
- # data/VOC/VOC2012/JPEGImages/2007_000423.jpg
- # >>> print(n_objs_list[idx])
- # 2
- # >>> print(imgs_ann_file_list[idx])
- # data/VOC/VOC2012/Annotations/2007_000423.xml
- # >>> print(objs_info_list[idx])
- # 14 0.173 0.461333333333 0.142 0.496
- # 14 0.828 0.542666666667 0.188 0.594666666667
- # >>> ann = tlx.prepro.parse_darknet_ann_str_to_list(objs_info_list[idx])
- # >>> print(ann)
- # [[14, 0.173, 0.461333333333, 0.142, 0.496], [14, 0.828, 0.542666666667, 0.188, 0.594666666667]]
- # >>> c, b = tlx.prepro.parse_darknet_ann_list_to_cls_box(ann)
- # >>> print(c, b)
- # [14, 14] [[0.173, 0.461333333333, 0.142, 0.496], [0.828, 0.542666666667, 0.188, 0.594666666667]]
- #
- # References
- # -------------
- # - `Pascal VOC2012 Website <http://host.robots.ox.ac.uk/pascal/VOC/voc2012/#devkit>`__.
- # - `Pascal VOC2007 Website <http://host.robots.ox.ac.uk/pascal/VOC/voc2007/>`__.
- #
- # """
- #
- # import xml.etree.ElementTree as ET
- #
- # try:
- # import lxml.etree as etree
- # except ImportError as e:
- # print(e)
- # raise ImportError("Module lxml not found. Please install lxml via pip or other package managers.")
- #
- # path = os.path.join(path, 'VOC')
- #
- # def _recursive_parse_xml_to_dict(xml):
- # """Recursively parses XML contents to python dict.
- #
- # We assume that `object` tags are the only ones that can appear
- # multiple times at the same level of a tree.
- #
- # Args:
- # xml: xml tree obtained by parsing XML file contents using lxml.etree
- #
- # Returns:
- # Python dictionary holding XML contents.
- #
- # """
- # if not xml:
- # # if xml is not None:
- # return {xml.tag: xml.text}
- # result = {}
- # for child in xml:
- # child_result = _recursive_parse_xml_to_dict(child)
- # if child.tag != 'object':
- # result[child.tag] = child_result[child.tag]
- # else:
- # if child.tag not in result:
- # result[child.tag] = []
- # result[child.tag].append(child_result[child.tag])
- # return {xml.tag: result}
- #
- # if dataset == "2012":
- # url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/"
- # tar_filename = "VOCtrainval_11-May-2012.tar"
- # extracted_filename = "VOC2012" # "VOCdevkit/VOC2012"
- # logging.info(" [============= VOC 2012 =============]")
- # elif dataset == "2012test":
- # extracted_filename = "VOC2012test" # "VOCdevkit/VOC2012"
- # logging.info(" [============= VOC 2012 Test Set =============]")
- # logging.info(
- # " \nAuthor: 2012test only have person annotation, so 2007test is highly recommended for testing !\n"
- # )
- # time.sleep(3)
- # if os.path.isdir(os.path.join(path, extracted_filename)) is False:
- # logging.info("For VOC 2012 Test data - online registration required")
- # logging.info(
- # " Please download VOC2012test.tar from: \n register: http://host.robots.ox.ac.uk:8080 \n voc2012 : http://host.robots.ox.ac.uk:8080/eval/challenges/voc2012/ \ndownload: http://host.robots.ox.ac.uk:8080/eval/downloads/VOC2012test.tar"
- # )
- # logging.info(" unzip VOC2012test.tar,rename the folder to VOC2012test and put it into %s" % path)
- # exit()
- # # # http://host.robots.ox.ac.uk:8080/eval/downloads/VOC2012test.tar
- # # url = "http://host.robots.ox.ac.uk:8080/eval/downloads/"
- # # tar_filename = "VOC2012test.tar"
- # elif dataset == "2007":
- # url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/"
- # tar_filename = "VOCtrainval_06-Nov-2007.tar"
- # extracted_filename = "VOC2007"
- # logging.info(" [============= VOC 2007 =============]")
- # elif dataset == "2007test":
- # # http://host.robots.ox.ac.uk/pascal/VOC/voc2007/index.html#testdata
- # # http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
- # url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/"
- # tar_filename = "VOCtest_06-Nov-2007.tar"
- # extracted_filename = "VOC2007test"
- # logging.info(" [============= VOC 2007 Test Set =============]")
- # else:
- # raise Exception("Please set the dataset aug to 2012, 2012test or 2007.")
- #
- # # download dataset
- # if dataset != "2012test":
- # _platform = sys.platform
- # if folder_exists(os.path.join(path, extracted_filename)) is False:
- # logging.info("[VOC] {} is nonexistent in {}".format(extracted_filename, path))
- # maybe_download_and_extract(tar_filename, path, url, extract=True)
- # del_file(os.path.join(path, tar_filename))
- # if dataset == "2012":
- # if _platform == "win32":
- # os.system("mv {}\VOCdevkit\VOC2012 {}\VOC2012".format(path, path))
- # else:
- # os.system("mv {}/VOCdevkit/VOC2012 {}/VOC2012".format(path, path))
- # elif dataset == "2007":
- # if _platform == "win32":
- # os.system("mv {}\VOCdevkit\VOC2007 {}\VOC2007".format(path, path))
- # else:
- # os.system("mv {}/VOCdevkit/VOC2007 {}/VOC2007".format(path, path))
- # elif dataset == "2007test":
- # if _platform == "win32":
- # os.system("mv {}\VOCdevkit\VOC2007 {}\VOC2007test".format(path, path))
- # else:
- # os.system("mv {}/VOCdevkit/VOC2007 {}/VOC2007test".format(path, path))
- # del_folder(os.path.join(path, 'VOCdevkit'))
- # # object classes(labels) NOTE: YOU CAN CUSTOMIZE THIS LIST
- # classes = [
- # "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog",
- # "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"
- # ]
- # if contain_classes_in_person:
- # classes_in_person = ["head", "hand", "foot"]
- # else:
- # classes_in_person = []
- #
- # classes += classes_in_person # use extra 3 classes for person
- #
- # classes_dict = utils.list_string_to_dict(classes)
- # logging.info("[VOC] object classes {}".format(classes_dict))
- #
- # # 1. image path list
- # # folder_imgs = path+"/"+extracted_filename+"/JPEGImages/"
- # folder_imgs = os.path.join(path, extracted_filename, "JPEGImages")
- # imgs_file_list = load_file_list(path=folder_imgs, regx='\\.jpg', printable=False)
- # logging.info("[VOC] {} images found".format(len(imgs_file_list)))
- #
- # imgs_file_list.sort(
- # key=lambda s: int(s.replace('.', ' ').replace('_', '').split(' ')[-2])
- # ) # 2007_000027.jpg --> 2007000027
- #
- # imgs_file_list = [os.path.join(folder_imgs, s) for s in imgs_file_list]
- # # logging.info('IM',imgs_file_list[0::3333], imgs_file_list[-1])
- # if dataset != "2012test":
- # # ======== 2. semantic segmentation maps path list
- # # folder_semseg = path+"/"+extracted_filename+"/SegmentationClass/"
- # folder_semseg = os.path.join(path, extracted_filename, "SegmentationClass")
- # imgs_semseg_file_list = load_file_list(path=folder_semseg, regx='\\.png', printable=False)
- # logging.info("[VOC] {} maps for semantic segmentation found".format(len(imgs_semseg_file_list)))
- # imgs_semseg_file_list.sort(
- # key=lambda s: int(s.replace('.', ' ').replace('_', '').split(' ')[-2])
- # ) # 2007_000032.png --> 2007000032
- # imgs_semseg_file_list = [os.path.join(folder_semseg, s) for s in imgs_semseg_file_list]
- # # logging.info('Semantic Seg IM',imgs_semseg_file_list[0::333], imgs_semseg_file_list[-1])
- # # ======== 3. instance segmentation maps path list
- # # folder_insseg = path+"/"+extracted_filename+"/SegmentationObject/"
- # folder_insseg = os.path.join(path, extracted_filename, "SegmentationObject")
- # imgs_insseg_file_list = load_file_list(path=folder_insseg, regx='\\.png', printable=False)
- # logging.info("[VOC] {} maps for instance segmentation found".format(len(imgs_semseg_file_list)))
- # imgs_insseg_file_list.sort(
- # key=lambda s: int(s.replace('.', ' ').replace('_', '').split(' ')[-2])
- # ) # 2007_000032.png --> 2007000032
- # imgs_insseg_file_list = [os.path.join(folder_insseg, s) for s in imgs_insseg_file_list]
- # # logging.info('Instance Seg IM',imgs_insseg_file_list[0::333], imgs_insseg_file_list[-1])
- # else:
- # imgs_semseg_file_list = []
- # imgs_insseg_file_list = []
- # # 4. annotations for bounding box and object class
- # # folder_ann = path+"/"+extracted_filename+"/Annotations/"
- # folder_ann = os.path.join(path, extracted_filename, "Annotations")
- # imgs_ann_file_list = load_file_list(path=folder_ann, regx='\\.xml', printable=False)
- # logging.info(
- # "[VOC] {} XML annotation files for bounding box and object class found".format(len(imgs_ann_file_list))
- # )
- # imgs_ann_file_list.sort(
- # key=lambda s: int(s.replace('.', ' ').replace('_', '').split(' ')[-2])
- # ) # 2007_000027.xml --> 2007000027
- # imgs_ann_file_list = [os.path.join(folder_ann, s) for s in imgs_ann_file_list]
- # # logging.info('ANN',imgs_ann_file_list[0::3333], imgs_ann_file_list[-1])
- #
- # if dataset == "2012test": # remove unused images in JPEG folder
- # imgs_file_list_new = []
- # for ann in imgs_ann_file_list:
- # ann = os.path.split(ann)[-1].split('.')[0]
- # for im in imgs_file_list:
- # if ann in im:
- # imgs_file_list_new.append(im)
- # break
- # imgs_file_list = imgs_file_list_new
- # logging.info("[VOC] keep %d images" % len(imgs_file_list_new))
- #
- # # parse XML annotations
- # def convert(size, box):
- # dw = 1. / size[0]
- # dh = 1. / size[1]
- # x = (box[0] + box[1]) / 2.0
- # y = (box[2] + box[3]) / 2.0
- # w = box[1] - box[0]
- # h = box[3] - box[2]
- # x = x * dw
- # w = w * dw
- # y = y * dh
- # h = h * dh
- # return x, y, w, h
- #
- # def convert_annotation(file_name):
- # """Given VOC2012 XML Annotations, returns number of objects and info."""
- # in_file = open(file_name)
- # out_file = ""
- # tree = ET.parse(in_file)
- # root = tree.getroot()
- # size = root.find('size')
- # w = int(size.find('width').text)
- # h = int(size.find('height').text)
- # n_objs = 0
- #
- # for obj in root.iter('object'):
- # if dataset != "2012test":
- # difficult = obj.find('difficult').text
- # cls = obj.find('name').text
- # if cls not in classes or int(difficult) == 1:
- # continue
- # else:
- # cls = obj.find('name').text
- # if cls not in classes:
- # continue
- # cls_id = classes.index(cls)
- # xmlbox = obj.find('bndbox')
- # b = (
- # float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
- # float(xmlbox.find('ymax').text)
- # )
- # bb = convert((w, h), b)
- #
- # out_file += str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n'
- # n_objs += 1
- # if cls in "person":
- # for part in obj.iter('part'):
- # cls = part.find('name').text
- # if cls not in classes_in_person:
- # continue
- # cls_id = classes.index(cls)
- # xmlbox = part.find('bndbox')
- # b = (
- # float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
- # float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text)
- # )
- # bb = convert((w, h), b)
- # # out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
- # out_file += str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n'
- # n_objs += 1
- # in_file.close()
- # return n_objs, out_file
- #
- # logging.info("[VOC] Parsing xml annotations files")
- # n_objs_list = []
- # objs_info_list = [] # Darknet Format list of string
- # objs_info_dicts = {}
- # for idx, ann_file in enumerate(imgs_ann_file_list):
- # n_objs, objs_info = convert_annotation(ann_file)
- # n_objs_list.append(n_objs)
- # objs_info_list.append(objs_info)
- # with tf.io.gfile.GFile(ann_file, 'r') as fid:
- # xml_str = fid.read()
- # xml = etree.fromstring(xml_str)
- # data = _recursive_parse_xml_to_dict(xml)['annotation']
- # objs_info_dicts.update({imgs_file_list[idx]: data})
- #
- # return imgs_file_list, imgs_semseg_file_list, imgs_insseg_file_list, imgs_ann_file_list, classes, classes_in_person, classes_dict, n_objs_list, objs_info_list, objs_info_dicts
-
-
- def load_mpii_pose_dataset(path='data', is_16_pos_only=False):
- """Load MPII Human Pose Dataset.
-
- Parameters
- -----------
- path : str
- The path that the data is downloaded to.
- is_16_pos_only : boolean
- If True, only return people with 16 pose keypoints. (Usually used for single-person pose estimation.)
-
- Returns
- ----------
- img_train_list : list of str
- The image directories of training data.
- ann_train_list : list of dict
- The annotations of training data.
- img_test_list : list of str
- The image directories of testing data.
- ann_test_list : list of dict
- The annotations of testing data.
-
- Examples
- --------
- >>> import pprint
- >>> import tensorlayerx as tlx
- >>> img_train_list, ann_train_list, img_test_list, ann_test_list = tlx.files.load_mpii_pose_dataset()
- >>> image = tlx.vis.read_image(img_train_list[0])
- >>> tlx.vis.draw_mpii_pose_to_image(image, ann_train_list[0], 'image.png')
- >>> pprint.pprint(ann_train_list[0])
-
- References
- -----------
- - `MPII Human Pose Dataset. CVPR 14 <http://human-pose.mpi-inf.mpg.de>`__
- - `MPII Human Pose Models. CVPR 16 <http://pose.mpi-inf.mpg.de>`__
- - `MPII Human Shape, Poselet Conditioned Pictorial Structures and etc <http://pose.mpi-inf.mpg.de/#related>`__
- - `MPII Keypoints and ID <http://human-pose.mpi-inf.mpg.de/#download>`__
- """
- path = os.path.join(path, 'mpii_human_pose')
- logging.info("Load or Download MPII Human Pose > {}".format(path))
-
- # annotation
- url = "http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/"
- tar_filename = "mpii_human_pose_v1_u12_2.zip"
- extracted_filename = "mpii_human_pose_v1_u12_2"
- if folder_exists(os.path.join(path, extracted_filename)) is False:
- logging.info("[MPII] (annotation) {} is nonexistent in {}".format(extracted_filename, path))
- maybe_download_and_extract(tar_filename, path, url, extract=True)
- del_file(os.path.join(path, tar_filename))
-
- # images
- url = "http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/"
- tar_filename = "mpii_human_pose_v1.tar.gz"
- extracted_filename2 = "images"
- if folder_exists(os.path.join(path, extracted_filename2)) is False:
- logging.info("[MPII] (images) {} is nonexistent in {}".format(extracted_filename, path))
- maybe_download_and_extract(tar_filename, path, url, extract=True)
- del_file(os.path.join(path, tar_filename))
-
- # parse annotation, format see http://human-pose.mpi-inf.mpg.de/#download
- logging.info("reading annotations from mat file ...")
- # mat = sio.loadmat(os.path.join(path, extracted_filename, "mpii_human_pose_v1_u12_1.mat"))
-
- # def fix_wrong_joints(joint): # https://github.com/mitmul/deeppose/blob/master/datasets/mpii_dataset.py
- # if '12' in joint and '13' in joint and '2' in joint and '3' in joint:
- # if ((joint['12'][0] < joint['13'][0]) and
- # (joint['3'][0] < joint['2'][0])):
- # joint['2'], joint['3'] = joint['3'], joint['2']
- # if ((joint['12'][0] > joint['13'][0]) and
- # (joint['3'][0] > joint['2'][0])):
- # joint['2'], joint['3'] = joint['3'], joint['2']
- # return joint
-
- ann_train_list = []
- ann_test_list = []
- img_train_list = []
- img_test_list = []
-
- def save_joints():
- # joint_data_fn = os.path.join(path, 'data.json')
- # fp = open(joint_data_fn, 'w')
- mat = sio.loadmat(os.path.join(path, extracted_filename, "mpii_human_pose_v1_u12_1.mat"))
-
- for _, (anno, train_flag) in enumerate( # all images
- zip(mat['RELEASE']['annolist'][0, 0][0], mat['RELEASE']['img_train'][0, 0][0])):
-
- img_fn = anno['image']['name'][0, 0][0]
- train_flag = int(train_flag)
-
- # print(i, img_fn, train_flag) # DEBUG print all images
-
- if train_flag:
- img_train_list.append(img_fn)
- ann_train_list.append([])
- else:
- img_test_list.append(img_fn)
- ann_test_list.append([])
-
- head_rect = []
- if 'x1' in str(anno['annorect'].dtype):
- head_rect = zip(
- [x1[0, 0] for x1 in anno['annorect']['x1'][0]], [y1[0, 0] for y1 in anno['annorect']['y1'][0]],
- [x2[0, 0] for x2 in anno['annorect']['x2'][0]], [y2[0, 0] for y2 in anno['annorect']['y2'][0]]
- )
- else:
- head_rect = [] # TODO
-
- if 'annopoints' in str(anno['annorect'].dtype):
- annopoints = anno['annorect']['annopoints'][0]
- head_x1s = anno['annorect']['x1'][0]
- head_y1s = anno['annorect']['y1'][0]
- head_x2s = anno['annorect']['x2'][0]
- head_y2s = anno['annorect']['y2'][0]
-
- for annopoint, head_x1, head_y1, head_x2, head_y2 in zip(annopoints, head_x1s, head_y1s, head_x2s,
- head_y2s):
- # if annopoint != []:
- # if len(annopoint) != 0:
- if annopoint.size:
- head_rect = [
- float(head_x1[0, 0]),
- float(head_y1[0, 0]),
- float(head_x2[0, 0]),
- float(head_y2[0, 0])
- ]
-
- # joint coordinates
- annopoint = annopoint['point'][0, 0]
- j_id = [str(j_i[0, 0]) for j_i in annopoint['id'][0]]
- x = [x[0, 0] for x in annopoint['x'][0]]
- y = [y[0, 0] for y in annopoint['y'][0]]
- joint_pos = {}
- for _j_id, (_x, _y) in zip(j_id, zip(x, y)):
- joint_pos[int(_j_id)] = [float(_x), float(_y)]
- # joint_pos = fix_wrong_joints(joint_pos)
-
- # visibility list
- if 'is_visible' in str(annopoint.dtype):
- vis = [v[0] if v.size > 0 else [0] for v in annopoint['is_visible'][0]]
- vis = dict([(k, int(v[0])) if len(v) > 0 else v for k, v in zip(j_id, vis)])
- else:
- vis = None
-
- # if len(joint_pos) == 16:
- if (not is_16_pos_only) or (len(joint_pos) == 16):
- # only use images with 16 keypoints, or use all images
- data = {
- 'filename': img_fn,
- 'train': train_flag,
- 'head_rect': head_rect,
- 'is_visible': vis,
- 'joint_pos': joint_pos
- }
- # print(json.dumps(data), file=fp) # py3
- if train_flag:
- ann_train_list[-1].append(data)
- else:
- ann_test_list[-1].append(data)
-
- # def write_line(datum, fp):
- # joints = sorted([[int(k), v] for k, v in datum['joint_pos'].items()])
- # joints = np.array([j for i, j in joints]).flatten()
- #
- # out = [datum['filename']]
- # out.extend(joints)
- # out = [str(o) for o in out]
- # out = ','.join(out)
- #
- # print(out, file=fp)
-
- # def split_train_test():
- # # fp_test = open('data/mpii/test_joints.csv', 'w')
- # fp_test = open(os.path.join(path, 'test_joints.csv'), 'w')
- # # fp_train = open('data/mpii/train_joints.csv', 'w')
- # fp_train = open(os.path.join(path, 'train_joints.csv'), 'w')
- # # all_data = open('data/mpii/data.json').readlines()
- # all_data = open(os.path.join(path, 'data.json')).readlines()
- # N = len(all_data)
- # N_test = int(N * 0.1)
- # N_train = N - N_test
- #
- # print('N:{}'.format(N))
- # print('N_train:{}'.format(N_train))
- # print('N_test:{}'.format(N_test))
- #
- # np.random.seed(1701)
- # perm = np.random.permutation(N)
- # test_indices = perm[:N_test]
- # train_indices = perm[N_test:]
- #
- # print('train_indices:{}'.format(len(train_indices)))
- # print('test_indices:{}'.format(len(test_indices)))
- #
- # for i in train_indices:
- # datum = json.loads(all_data[i].strip())
- # write_line(datum, fp_train)
- #
- # for i in test_indices:
- # datum = json.loads(all_data[i].strip())
- # write_line(datum, fp_test)
-
- save_joints()
- # split_train_test() #
-
- # read images dir
- logging.info("reading images list ...")
- img_dir = os.path.join(path, extracted_filename2)
- _img_list = load_file_list(path=os.path.join(path, extracted_filename2), regx='\\.jpg', printable=False)
- # ann_list = json.load(open(os.path.join(path, 'data.json')))
- for i, im in enumerate(img_train_list):
- if im not in _img_list:
- print('missing training image {} in {} (remove from img(ann)_train_list)'.format(im, img_dir))
- # img_train_list.remove(im)
- del img_train_list[i]
- del ann_train_list[i]
- for i, im in enumerate(img_test_list):
- if im not in _img_list:
- print('missing testing image {} in {} (remove from img(ann)_test_list)'.format(im, img_dir))
- # img_test_list.remove(im)
- del img_test_list[i]
- del ann_test_list[i]
-
- # check annotation and images
- n_train_images = len(img_train_list)
- n_test_images = len(img_test_list)
- n_images = n_train_images + n_test_images
- logging.info("n_images: {} n_train_images: {} n_test_images: {}".format(n_images, n_train_images, n_test_images))
- n_train_ann = len(ann_train_list)
- n_test_ann = len(ann_test_list)
- n_ann = n_train_ann + n_test_ann
- logging.info("n_ann: {} n_train_ann: {} n_test_ann: {}".format(n_ann, n_train_ann, n_test_ann))
- n_train_people = len(sum(ann_train_list, []))
- n_test_people = len(sum(ann_test_list, []))
- n_people = n_train_people + n_test_people
- logging.info("n_people: {} n_train_people: {} n_test_people: {}".format(n_people, n_train_people, n_test_people))
- # add path to all image file name
- for i, value in enumerate(img_train_list):
- img_train_list[i] = os.path.join(img_dir, value)
- for i, value in enumerate(img_test_list):
- img_test_list[i] = os.path.join(img_dir, value)
- return img_train_list, ann_train_list, img_test_list, ann_test_list
-
-
- def save_npz(save_list=None, name='model.npz'):
- """Input parameters and the file name, save parameters into .npz file. Use tlx.utils.load_npz() to restore.
-
- Parameters
- ----------
- save_list : list of tensor
- A list of parameters (tensor) to be saved.
- name : str
- The name of the `.npz` file.
-
- Examples
- --------
- Save model to npz
-
- >>> tlx.files.save_npz(network.all_weights, name='model.npz')
-
- Load model from npz (Method 1)
-
- >>> load_params = tlx.files.load_npz(name='model.npz')
- >>> tlx.files.assign_weights(load_params, network)
-
- Load model from npz (Method 2)
-
- >>> tlx.files.load_and_assign_npz(name='model.npz', network=network)
-
- References
- ----------
- `Saving dictionary using numpy <http://stackoverflow.com/questions/22315595/saving-dictionary-of-header-information-using-numpy-savez>`__
-
- """
- logging.info("[*] Saving TLX weights into %s" % name)
- if save_list is None:
- save_list = []
-
- if tlx.BACKEND == 'tensorflow':
- save_list_var = tf_variables_to_numpy(save_list)
- elif tlx.BACKEND == 'mindspore':
- save_list_var = ms_variables_to_numpy(save_list)
- elif tlx.BACKEND == 'paddle':
- save_list_var = pd_variables_to_numpy(save_list)
- elif tlx.BACKEND == 'torch':
- save_list_var = th_variables_to_numpy(save_list)
- else:
- raise NotImplementedError("This backend is not supported")
- # Number by length
- save_list_names = [str(i) for i in range(len(save_list_var))]
- save_var_dict = {save_list_names[idx]: val for idx, val in enumerate(save_list_var)}
- np.savez(name, **save_var_dict)
- save_list_var = None
- save_var_dict = None
- del save_list_var
- del save_var_dict
- logging.info("[*] Saved")
-
-
- def load_npz(path='', name='model.npz'):
- """Load the parameters of a Model saved by tlx.files.save_npz().
-
- Parameters
- ----------
- path : str
- Folder path to `.npz` file.
- name : str
- The name of the `.npz` file.
-
- Returns
- --------
- list of array
- A list of parameters in order.
-
- Examples
- --------
- - See ``tlx.files.save_npz``
-
- References
- ----------
- - `Saving dictionary using numpy <http://stackoverflow.com/questions/22315595/saving-dictionary-of-header-information-using-numpy-savez>`__
-
- """
- d = np.load(os.path.join(path, name), allow_pickle=True)
- return [d[str(i)] for i in range(len(d))]
-
-
- def assign_params(**kwargs):
- raise Exception("please change assign_params --> assign_weights")
-
-
- def assign_weights(weights, network):
- """Assign the given parameters to the TensorLayer network.
-
- Parameters
- ----------
- weights : list of array
- A list of model weights (array) in order.
- network : :class:`Layer`
- The network to be assigned.
-
- Returns
- --------
- 1) list of operations if in graph mode
- A list of tf ops, in order, that assign the weights; they can be run manually with sess.run(ops).
- 2) list of tf variables if in eager mode
- A list of tf variables (assigned weights) in order.
-
- Examples
- --------
-
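- A minimal sketch, assuming ``network`` is a built TLX model and ``'model.npz'`` was produced by ``tlx.files.save_npz(network.all_weights)``:
- 
- >>> # load the saved arrays and copy them into the network's weights, in order
- >>> weights = tlx.files.load_npz(name='model.npz')
- >>> tlx.files.assign_weights(weights, network)
- 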
- References
- ----------
- - `Assign value to a TensorFlow variable <http://stackoverflow.com/questions/34220532/how-to-assign-value-to-a-tensorflow-variable>`__
-
- """
- ops = []
- if tlx.BACKEND == 'tensorflow':
- for idx, param in enumerate(weights):
- ops.append(network.all_weights[idx].assign(param))
-
- elif tlx.BACKEND == 'mindspore':
-
- class Assign_net(Cell):
-
- def __init__(self, y):
- super(Assign_net, self).__init__()
- self.y = y
-
- def construct(self, x):
- Assign()(self.y, x)
-
- for idx, param in enumerate(weights):
- assign_param = Tensor(param, dtype=ms.float32)
- # net = Assign_net(network.all_weights[idx])
- # net(assign_param)
- Assign()(network.all_weights[idx], assign_param)
- elif tlx.BACKEND == 'paddle':
- for idx, param in enumerate(weights):
- assign_pd_variable(network.all_weights[idx], param)
- elif tlx.BACKEND == 'torch':
- for idx, param in enumerate(weights):
- assign_th_variable(network.all_weights[idx], param)
-
- else:
- raise NotImplementedError("This backend is not supported")
- return ops
-
-
- def load_and_assign_npz(name=None, network=None):
- """Load model from npz and assign to a network.
-
- Parameters
- -------------
- name : str
- The name of the `.npz` file.
- network : :class:`Model`
- The network to be assigned.
-
- Examples
- --------
- - See ``tlx.files.save_npz``
-
- """
- if network is None:
- raise ValueError("network is None.")
-
- if not os.path.exists(name):
- logging.error("file {} doesn't exist.".format(name))
- return False
- else:
- weights = load_npz(name=name)
- assign_weights(weights, network)
- logging.info("[*] Load {} SUCCESS!".format(name))
-
-
- def save_npz_dict(save_list=None, name='model.npz'):
- """Input parameters and the file name, save parameters as a dictionary into .npz file.
-
- Use ``tlx.files.load_and_assign_npz_dict()`` to restore.
-
- Parameters
- ----------
- save_list : list of parameters
- A list of parameters (tensor) to be saved.
- name : str
- The name of the `.npz` file.
-
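- Examples
- --------
- A minimal sketch, assuming a non-torch backend where ``network.all_weights`` is a list of named tensors:
- 
- >>> # save weights keyed by variable name, then restore them by matching names
- >>> tlx.files.save_npz_dict(network.all_weights, name='model_dict.npz')
- >>> tlx.files.load_and_assign_npz_dict(name='model_dict.npz', network=network)
- 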
- """
- if save_list is None:
- save_list = []
- if tlx.BACKEND != 'torch':
- save_list_names = [tensor.name for tensor in save_list]
-
- if tlx.BACKEND == 'tensorflow':
- save_list_var = tf_variables_to_numpy(save_list)
- elif tlx.BACKEND == 'mindspore':
- save_list_var = ms_variables_to_numpy(save_list)
- elif tlx.BACKEND == 'paddle':
- save_list_var = pd_variables_to_numpy(save_list)
- elif tlx.BACKEND == 'torch':
- save_list_names = []
- save_list_var = []
- for named, values in save_list:
- save_list_names.append(named)
- save_list_var.append(values.cpu().detach().numpy())
- else:
- raise NotImplementedError('Not implemented')
- save_var_dict = {save_list_names[idx]: val for idx, val in enumerate(save_list_var)}
- np.savez(name, **save_var_dict)
- save_list_var = None
- save_var_dict = None
- del save_list_var
- del save_var_dict
- logging.info("[*] Model saved in npz_dict %s" % name)
-
-
- def load_and_assign_npz_dict(name='model.npz', network=None, skip=False):
- """Restore the parameters saved by ``tlx.files.save_npz_dict()``.
-
- Parameters
- -------------
- name : str
- The name of the `.npz` file.
- network : :class:`Model`
- The network to be assigned.
- skip : boolean
- If True, loaded weights whose names are not found in the network's weights will be skipped.
- If False, an error will be raised when a mismatch is found. Default is False.
-
- """
- if not os.path.exists(name):
- logging.error("file {} doesn't exist.".format(name))
- return False
-
- weights = np.load(name, allow_pickle=True)
- if len(weights.keys()) != len(set(weights.keys())):
- raise Exception("Duplication in model npz_dict %s" % name)
-
- if tlx.BACKEND == 'torch':
- net_weights_name = [n for n, v in network.named_parameters()]
- torch_weights_dict = {n: v for n, v in network.named_parameters()}
- else:
- net_weights_name = [w.name for w in network.all_weights]
-
- for key in weights.keys():
- if key not in net_weights_name:
- if skip:
- logging.warning("Weights named '%s' not found in network. Skip it." % key)
- else:
- raise RuntimeError(
- "Weights named '%s' not found in network. Hint: set argument skip=Ture "
- "if you want to skip redundant or mismatch weights." % key
- )
- else:
- if tlx.BACKEND == 'tensorflow':
- assign_tf_variable(network.all_weights[net_weights_name.index(key)], weights[key])
- elif tlx.BACKEND == 'mindspore':
- assign_param = Tensor(weights[key], dtype=ms.float32)
- assign_ms_variable(network.all_weights[net_weights_name.index(key)], assign_param)
- elif tlx.BACKEND == 'paddle':
- assign_pd_variable(network.all_weights[net_weights_name.index(key)], weights[key])
- elif tlx.BACKEND == 'torch':
- assign_th_variable(torch_weights_dict[key], weights[key])
- else:
- raise NotImplementedError('Not implemented')
-
- logging.info("[*] Model restored from npz_dict %s" % name)
-
-
- def save_ckpt(sess=None, mode_name='model.ckpt', save_dir='checkpoint', var_list=None, global_step=None, printable=False):
- """Save parameters into `ckpt` file.
-
- Parameters
- ------------
- sess : Session
- TensorFlow Session, required when saving in graph mode.
- mode_name : str
- The name of the model, default is ``model.ckpt``.
- save_dir : str
- The path / file directory to the `ckpt`, default is ``checkpoint``.
- var_list : list of tensor
- The parameters / variables (tensor) to be saved. If empty, save all global variables (default).
- global_step : int or None
- Step number.
- printable : boolean
- Whether to print all parameters information.
-
- See Also
- --------
- load_ckpt
-
- """
-
- if var_list is None:
- if sess is None:
- # FIXME: not sure whether global variables can be accessed in eager mode
- raise ValueError(
- "If var_list is None, sess must be specified. "
- "In eager mode, can not access global variables easily. "
- )
- var_list = []
-
- ckpt_file = os.path.join(save_dir, mode_name)
- if var_list == []:
- var_list = tf.global_variables()
-
- logging.info("[*] save %s n_weights: %d" % (ckpt_file, len(var_list)))
-
- if printable:
- for idx, v in enumerate(var_list):
- logging.info(" param {:3}: {:15} {}".format(idx, v.name, str(v.get_shape())))
-
- if sess:
- # graph mode
- saver = tf.train.Saver(var_list)
- saver.save(sess, ckpt_file, global_step=global_step)
- else:
- # eager mode
- # saver = tfes.Saver(var_list)
- # saver.save(ckpt_file, global_step=global_step)
- # TODO: tf2.0 not stable, cannot import tensorflow.contrib.eager.python.saver
- pass
-
-
- def load_ckpt(sess=None, mode_name='model.ckpt', save_dir='checkpoint', var_list=None, is_latest=True, printable=False):
- """Load parameters from `ckpt` file.
-
- Parameters
- ------------
- sess : Session
- TensorFlow Session.
- mode_name : str
- The name of the model, default is ``model.ckpt``.
- save_dir : str
- The path / file directory to the `ckpt`, default is ``checkpoint``.
- var_list : list of tensor
- The parameters / variables (tensor) to be loaded. If empty, load all global variables (default).
- is_latest : boolean
- Whether to load the latest `ckpt`. If False, load the `ckpt` with the name given by ``mode_name``.
- printable : boolean
- Whether to print all parameters information.
-
- Examples
- ----------
- - Save all global parameters.
-
- >>> tlx.files.save_ckpt(sess=sess, mode_name='model.ckpt', save_dir='model', printable=True)
-
- - Save specific parameters.
-
- >>> tlx.files.save_ckpt(sess=sess, mode_name='model.ckpt', var_list=net.all_params, save_dir='model', printable=True)
-
- - Load latest ckpt.
-
- >>> tlx.files.load_ckpt(sess=sess, var_list=net.all_params, save_dir='model', printable=True)
-
- - Load specific ckpt.
-
- >>> tlx.files.load_ckpt(sess=sess, mode_name='model.ckpt', var_list=net.all_params, save_dir='model', is_latest=False, printable=True)
-
- """
- # if sess is None:
- # raise ValueError("session is None.")
- if var_list is None:
- if sess is None:
- # FIXME: not sure whether global variables can be accessed in eager mode
- raise ValueError(
- "If var_list is None, sess must be specified. "
- "In eager mode, can not access global variables easily. "
- )
- var_list = []
-
- if is_latest:
- ckpt_file = tf.train.latest_checkpoint(save_dir)
- else:
- ckpt_file = os.path.join(save_dir, mode_name)
-
- if not var_list:
- var_list = tf.global_variables()
-
- logging.info("[*] load %s n_weights: %d" % (ckpt_file, len(var_list)))
-
- if printable:
- for idx, v in enumerate(var_list):
- logging.info(" weights {:3}: {:15} {}".format(idx, v.name, str(v.get_shape())))
-
- try:
- if sess:
- # graph mode
- saver = tf.train.Saver(var_list)
- saver.restore(sess, ckpt_file)
- else:
- # eager mode
- # saver = tfes.Saver(var_list)
- # saver.restore(ckpt_file)
- # TODO: tf2.0 not stable, cannot import tensorflow.contrib.eager.python.saver
- pass
-
- except Exception as e:
- logging.info(e)
- logging.info("[*] load ckpt fail ...")
-
-
- def save_any_to_npy(save_dict=None, name='file.npy'):
- """Save variables to `.npy` file.
-
- Parameters
- ------------
- save_dict : dict
- The variables to be saved.
- name : str
- File name.
-
- Examples
- ---------
- >>> tlx.files.save_any_to_npy(save_dict={'data': ['a','b']}, name='test.npy')
- >>> data = tlx.files.load_npy_to_any(name='test.npy')
- >>> print(data)
- {'data': ['a','b']}
-
- """
- if save_dict is None:
- save_dict = {}
- np.save(name, save_dict)
-
-
- def load_npy_to_any(path='', name='file.npy'):
- """Load `.npy` file.
-
- Parameters
- ------------
- path : str
- Path to the file (optional).
- name : str
- File name.
-
- Examples
- ---------
- - see tlx.files.save_any_to_npy()
-
- """
- file_path = os.path.join(path, name)
- try:
- return np.load(file_path, allow_pickle=True).item()
- except Exception:
- try:
- return np.load(file_path, allow_pickle=True)
- except Exception:
- raise Exception("[!] Fail to load %s" % file_path)
-
-
- def file_exists(filepath):
- """Check whether a file exists by given file path."""
- return os.path.isfile(filepath)
-
-
- def folder_exists(folderpath):
- """Check whether a folder exists by given folder path."""
- return os.path.isdir(folderpath)
-
-
- def del_file(filepath):
- """Delete a file by given file path."""
- os.remove(filepath)
-
-
- def del_folder(folderpath):
- """Delete a folder by given folder path."""
- shutil.rmtree(folderpath)
-
-
- def read_file(filepath):
- """Read a file and return a string.
-
- Examples
- ---------
- >>> data = tlx.files.read_file('data.txt')
-
- """
- with open(filepath, 'r') as afile:
- return afile.read()
-
-
- def load_file_list(path=None, regx='\.jpg', printable=True, keep_prefix=False):
- r"""Return a file list in a folder by given a path and regular expression.
-
- Parameters
- ----------
- path : str or None
- A folder path, if `None`, use the current directory.
- regx : str
- The regx of file name.
- printable : boolean
- Whether to print the file information.
- keep_prefix : boolean
- Whether to keep path in the file name.
-
- Examples
- ----------
- >>> file_list = tlx.files.load_file_list(path=None, regx='w1pre_[0-9]+\.(npz)')
-
- """
- if path is None:
- path = os.getcwd()
- file_list = os.listdir(path)
- return_list = []
- for _, f in enumerate(file_list):
- if re.search(regx, f):
- return_list.append(f)
- # return_list.sort()
- if keep_prefix:
- for i, f in enumerate(return_list):
- return_list[i] = os.path.join(path, f)
-
- if printable:
- logging.info('Match file list = %s' % return_list)
- logging.info('Number of files = %d' % len(return_list))
- return return_list
-
-
- def load_folder_list(path=""):
- """Return a folder list in a folder by given a folder path.
-
- Parameters
- ----------
- path : str
- A folder path.
-
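- Examples
- --------
- A minimal sketch; ``'data/'`` is only an illustrative path:
- 
- >>> folders = tlx.files.load_folder_list(path='data/')
- 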
- """
- return [os.path.join(path, o) for o in os.listdir(path) if os.path.isdir(os.path.join(path, o))]
-
-
- def exists_or_mkdir(path, verbose=True):
- """Check a folder by given name, if not exist, create the folder and return False,
- if directory exists, return True.
-
- Parameters
- ----------
- path : str
- A folder path.
- verbose : boolean
- If True (default), prints results.
-
- Returns
- --------
- boolean
- True if the folder already exists; otherwise returns False after creating the folder.
-
- Examples
- --------
- >>> tlx.files.exists_or_mkdir("checkpoints/train")
-
- """
- if not os.path.exists(path):
- if verbose:
- logging.info("[*] creates %s ..." % path)
- os.makedirs(path)
- return False
- else:
- if verbose:
- logging.info("[!] %s exists ..." % path)
- return True
-
-
- def maybe_download_and_extract(filename, working_directory, url_source, extract=False, expected_bytes=None):
- """Checks if file exists in working_directory otherwise tries to dowload the file,
- and optionally also tries to extract the file if format is ".zip" or ".tar"
-
- Parameters
- -----------
- filename : str
- The name of the (to be) downloaded file.
- working_directory : str
- A folder path to search for the file in and download the file to.
- url_source : str
- The URL to download the file from.
- extract : boolean
- If True, tries to uncompress the downloaded file if it is a ".tar.gz/.tar.bz2" or ".zip" file, default is False.
- expected_bytes : int or None
- If set, verifies that the downloaded file has the specified size and raises an Exception otherwise. Default is None, meaning no check is performed.
-
- Returns
- ----------
- str
- File path of the downloaded (uncompressed) file.
-
- Examples
- --------
- >>> down_file = tlx.files.maybe_download_and_extract(filename='train-images-idx3-ubyte.gz',
- ... working_directory='data/',
- ... url_source='http://yann.lecun.com/exdb/mnist/')
- >>> tlx.files.maybe_download_and_extract(filename='ADEChallengeData2016.zip',
- ... working_directory='data/',
- ... url_source='http://sceneparsing.csail.mit.edu/data/',
- ... extract=True)
-
- """
-
- # We first define a download function, supporting both Python 2 and 3.
- def _download(filename, working_directory, url_source):
-
- progress_bar = progressbar.ProgressBar()
-
- def _dlProgress(count, blockSize, totalSize, pbar=progress_bar):
- if (totalSize != 0):
-
- if not pbar.max_value:
- totalBlocks = math.ceil(float(totalSize) / float(blockSize))
- pbar.max_value = int(totalBlocks)
-
- pbar.update(count, force=True)
-
- filepath = os.path.join(working_directory, filename)
-
- logging.info('Downloading %s...\n' % filename)
-
- urlretrieve(url_source + filename, filepath, reporthook=_dlProgress)
-
- exists_or_mkdir(working_directory, verbose=False)
- filepath = os.path.join(working_directory, filename)
-
- if not os.path.exists(filepath):
-
- _download(filename, working_directory, url_source)
- statinfo = os.stat(filepath)
- logging.info('Successfully downloaded %s %s bytes.' % (filename, statinfo.st_size))
- if (not (expected_bytes is None) and (expected_bytes != statinfo.st_size)):
- raise Exception('Failed to verify ' + filename + '. Can you get to it with a browser?')
- if (extract):
- if tarfile.is_tarfile(filepath):
- logging.info('Trying to extract tar file')
- tarfile.open(filepath, 'r').extractall(working_directory)
- logging.info('... Success!')
- elif zipfile.is_zipfile(filepath):
- logging.info('Trying to extract zip file')
- with zipfile.ZipFile(filepath) as zf:
- zf.extractall(working_directory)
- logging.info('... Success!')
- else:
- logging.info("Unknown compression_format only .tar.gz/.tar.bz2/.tar and .zip supported")
- return filepath
-
-
- def natural_keys(text):
- """Sort list of string with number in human order.
-
- Examples
- ----------
- >>> l = ['im1.jpg', 'im31.jpg', 'im11.jpg', 'im21.jpg', 'im03.jpg', 'im05.jpg']
- >>> l.sort(key=tlx.files.natural_keys)
- ['im1.jpg', 'im03.jpg', 'im05.jpg', 'im11.jpg', 'im21.jpg', 'im31.jpg']
- >>> l.sort() # that is what we don't want
- ['im03.jpg', 'im05.jpg', 'im1.jpg', 'im11.jpg', 'im21.jpg', 'im31.jpg']
-
- References
- ----------
- - `link <http://nedbatchelder.com/blog/200712/human_sorting.html>`__
-
- """
-
- # - alist.sort(key=natural_keys) sorts in human order
- # http://nedbatchelder.com/blog/200712/human_sorting.html
- # (See Toothy's implementation in the comments)
- def atoi(text):
- return int(text) if text.isdigit() else text
-
- return [atoi(c) for c in re.split('(\d+)', text)]
-
-
- # Visualizing npz files
- def npz_to_W_pdf(path=None, regx='w1pre_[0-9]+\.(npz)'):
- r"""Convert the first weight matrix of `.npz` file to `.pdf` by using `tlx.visualize.W()`.
-
- Parameters
- ----------
- path : str
- A folder path to `npz` files.
- regx : str
- Regx for the file name.
-
- Examples
- ---------
- Convert the first weight matrix of w1_pre...npz file to w1_pre...pdf.
-
- >>> tlx.files.npz_to_W_pdf(path='/Users/.../npz_file/', regx='w1pre_[0-9]+\.(npz)')
-
- """
- file_list = load_file_list(path=path, regx=regx)
- for f in file_list:
- W = load_npz(path, f)[0]
- logging.info("%s --> %s" % (f, f.split('.')[0] + '.pdf'))
- visualize.draw_weights(W, second=10, saveable=True, name=f.split('.')[0], fig_idx=2012)
-
-
- def tf_variables_to_numpy(variables):
- """Convert TF tensor or a list of tensors into a list of numpy array"""
- if not isinstance(variables, list):
- var_list = [variables]
- else:
- var_list = variables
-
- results = [v.numpy() for v in var_list]
- return results
-
-
- def ms_variables_to_numpy(variables):
- """Convert MS tensor or list of tensors into a list of numpy array"""
- if not isinstance(variables, list):
- var_list = [variables]
- else:
- var_list = variables
-
- results = [v.data.asnumpy() for v in var_list]
- return results
-
-
- def pd_variables_to_numpy(variables):
- """Convert a Paddle tensor or a list of tensors into a list of numpy arrays."""
- if not isinstance(variables, list):
- var_list = [variables]
- else:
- var_list = variables
-
- results = [v.numpy() for v in var_list]
- return results
-
-
- def th_variables_to_numpy(variables):
- """Convert a PyTorch tensor or a list of tensors into a list of numpy arrays."""
- if not isinstance(variables, list):
- var_list = [variables]
- else:
- var_list = variables
- results = [v.cpu().detach().numpy() for v in var_list]
- return results
-
-
- def assign_tf_variable(variable, value):
- """Assign value to a TF variable"""
- variable.assign(value)
-
-
- def assign_ms_variable(variable, value):
- """Assign value to a MindSpore variable"""
-
- class Assign_net(Cell):
-
- def __init__(self, y):
- super(Assign_net, self).__init__()
- self.y = y
-
- def construct(self, x):
- Assign()(self.y, x)
-
- # net = Assign_net(variable)
- # net(value)
- Assign()(variable, value)
-
-
- def assign_pd_variable(variable, value):
- """Assign value to a Paddle variable"""
- variable.set_value(value)
-
-
- def assign_th_variable(variable, value):
- """Assign value to a PyTorch variable"""
- variable.data = torch.as_tensor(value)
-
-
- def _save_weights_to_hdf5_group(f, save_list):
- """
- Save layer/model weights into hdf5 group recursively.
-
- Parameters
- ----------
- f: hdf5 group
- A hdf5 group created by h5py.File() or create_group().
- save_list: list
- A list of model parameters (tensors) to save.
-
- """
-
- if save_list is None:
- save_list = []
- if tlx.BACKEND != 'torch':
- save_list_names = [tensor.name for tensor in save_list]
-
- if tlx.BACKEND == 'tensorflow':
- save_list_var = tf_variables_to_numpy(save_list)
- elif tlx.BACKEND == 'mindspore':
- save_list_var = ms_variables_to_numpy(save_list)
- elif tlx.BACKEND == 'paddle':
- save_list_var = pd_variables_to_numpy(save_list)
- elif tlx.BACKEND == 'torch':
- save_list_names = []
- save_list_var = []
- for named, values in save_list:
- save_list_names.append(named)
- save_list_var.append(values.cpu().detach().numpy())
- else:
- raise NotImplementedError('Not implemented')
- save_var_dict = {save_list_names[idx]: val for idx, val in enumerate(save_list_var)}
-
- g = f.create_group('model_parameters')
- for k in save_var_dict.keys():
- g.create_dataset('.'.join(k.split('/')), data=save_var_dict[k])
-
- save_list_var = None
- save_var_dict = None
- del save_list_var
- del save_var_dict
- logging.info("[*] Model saved in hdf5.")
-
-
- def _load_weights_from_hdf5_group_in_order(f, layers):
- """
- Load layer weights from a hdf5 group sequentially.
-
- Parameters
- ----------
- f: hdf5 group
- A hdf5 group created by h5py.File() or create_group().
- layers: list
- A list of layers to load weights.
-
- """
- layer_names = [n.decode('utf8') for n in f.attrs["layer_names"]]
-
- for idx, name in enumerate(layer_names):
- g = f[name]
- layer = layers[idx]
- if isinstance(layer, tlx.model.Model):
- _load_weights_from_hdf5_group_in_order(g, layer.all_layers)
- elif isinstance(layer, tlx.nn.ModelLayer):
- _load_weights_from_hdf5_group_in_order(g, layer.model.all_layers)
- elif isinstance(layer, tlx.nn.layers.ModuleList):
- _load_weights_from_hdf5_group_in_order(g, layer.layers)
- elif isinstance(layer, tlx.nn.Layer):
- weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
- for iid, w_name in enumerate(weight_names):
- assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name]))
- else:
- raise Exception("Only layer or model can be saved into hdf5.")
- if idx == len(layers) - 1:
- break
-
-
- def _load_weights_from_hdf5_group(f, layers, skip=False):
- """
- Load layer weights from a hdf5 group by layer name.
-
- Parameters
- ----------
- f: hdf5 group
- A hdf5 group created by h5py.File() or create_group().
- layers: list
- A list of layers to load weights.
- skip : boolean
- If True, a loaded layer whose name is not found in 'layers' will be skipped. If False,
- an error will be raised when a mismatch is found. Default is False.
-
- """
- layer_names = [n.decode('utf8') for n in f.attrs["layer_names"]]
- layer_index = {layer.name: layer for layer in layers}
-
- for idx, name in enumerate(layer_names):
- if name not in layer_index.keys():
- if skip:
- logging.warning("Layer named '%s' not found in network. Skip it." % name)
- else:
- raise RuntimeError(
- "Layer named '%s' not found in network. Hint: set argument skip=Ture "
- "if you want to skip redundant or mismatch Layers." % name
- )
- else:
- g = f[name]
- layer = layer_index[name]
- if isinstance(layer, tlx.model.Model):
- _load_weights_from_hdf5_group(g, layer.all_layers, skip)
- elif isinstance(layer, tlx.nn.ModelLayer):
- _load_weights_from_hdf5_group(g, layer.model.all_layers, skip)
- elif isinstance(layer, tlx.nn.ModuleList):
- _load_weights_from_hdf5_group(g, layer.layers, skip)
- elif isinstance(layer, tlx.nn.Layer):
- weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
- for iid, w_name in enumerate(weight_names):
- # FIXME : this is only for compatibility
- if isinstance(layer, tlx.nn.BatchNorm) and np.asarray(g[w_name]).ndim > 1:
- assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name]).squeeze())
- continue
- assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name]))
- else:
- raise Exception("Only layer or model can be saved into hdf5.")
-
-
- def save_weights_to_hdf5(save_list, filepath):
- """Input filepath and save weights in hdf5 format.
-
- Parameters
- ----------
- filepath : str
- Filename to which the weights will be saved.
- network : Model
- TL model.
-
- Returns
- -------
-
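- Examples
- --------
- A minimal sketch, assuming a non-torch backend where ``network.all_weights`` is a list of named tensors:
- 
- >>> # write the weights into an hdf5 file, then read them back and assign by name
- >>> tlx.files.save_weights_to_hdf5(network.all_weights, 'model_weights.hdf5')
- >>> tlx.files.load_hdf5_to_weights_in_order('model_weights.hdf5', network)
- 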
- """
- logging.info("[*] Saving TL weights into %s" % filepath)
-
- with h5py.File(filepath, 'w') as f:
- _save_weights_to_hdf5_group(f, save_list)
-
- logging.info("[*] Saved")
-
-
- def load_hdf5_to_weights_in_order(filepath, network, skip=False):
- """Load weights sequentially from a given file of hdf5 format
-
- Parameters
- ----------
- filepath : str
- Filename from which the weights will be loaded; should be in hdf5 format.
- network : Model
- TL model.
- skip : bool
- If True, weights in the file whose names are not found in the network will be skipped. Default is False.
-
- Notes
- -----
- If the file contains more weights than given 'weights', then the redundant ones will be ignored
- if all previous weights match perfectly.
-
- Returns
- -------
-
- """
- f = h5py.File(filepath, 'r')
- weights = f['model_parameters']
- if len(weights.keys()) != len(set(weights.keys())):
- raise Exception("Duplication in model npz_dict %s" % name)
-
- if tlx.BACKEND == 'torch':
- net_weights_name = [n for n, v in network.named_parameters()]
- torch_weights_dict = {n: v for n, v in network.named_parameters()}
- else:
- net_weights_name = [w.name for w in network.all_weights]
-
- for key in weights.keys():
- key_t = '/'.join(key.split('.'))
- if key_t not in net_weights_name:
- if skip:
- logging.warning("Weights named '%s' not found in network. Skip it." % key)
- else:
- raise RuntimeError(
- "Weights named '%s' not found in network. Hint: set argument skip=Ture "
- "if you want to skip redundant or mismatch weights." % key
- )
- else:
- if tlx.BACKEND == 'tensorflow':
- assign_tf_variable(network.all_weights[net_weights_name.index(key_t)], weights[key])
- elif tlx.BACKEND == 'mindspore':
- assign_param = Tensor(weights[key], dtype=ms.float32)
- assign_ms_variable(network.all_weights[net_weights_name.index(key_t)], assign_param)
- elif tlx.BACKEND == 'paddle':
- assign_pd_variable(network.all_weights[net_weights_name.index(key_t)], weights[key])
- elif tlx.BACKEND == 'torch':
- assign_th_variable(torch_weights_dict[key_t], weights[key])
- else:
- raise NotImplementedError('Not implemented')
- f.close()
- logging.info("[*] Load %s SUCCESS!" % filepath)
-
-
- def load_hdf5_to_weights(filepath, network, skip=False):
- """Load weights by name from a given file of hdf5 format
-
- Parameters
- ----------
- filepath : str
- Filename from which the weights will be loaded; should be in hdf5 format.
- network : Model
- TL model.
- skip : bool
- If True, loaded weights whose names are not found in 'weights' will be skipped. If False,
- an error will be raised when a mismatch is found. Default is False.
-
- Returns
- -------
-
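- Examples
- --------
- A minimal sketch, assuming ``'model.hdf5'`` was saved from a TL model (so it carries 'layer_names' attributes) and ``network`` uses matching layer names:
- 
- >>> # restore weights layer by layer, matching on layer name and skipping unknown layers
- >>> tlx.files.load_hdf5_to_weights('model.hdf5', network, skip=True)
- 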
- """
- f = h5py.File(filepath, 'r')
- try:
- layer_names = [n.decode('utf8') for n in f.attrs["layer_names"]]
- except Exception:
- raise NameError(
- "The loaded hdf5 file needs to have 'layer_names' as attributes. "
- "Please check whether this hdf5 file is saved from TL."
- )
-
- net_index = {layer.name: layer for layer in network.all_layers}
-
- if len(network.all_layers) != len(layer_names):
- logging.warning(
- "Number of weights mismatch."
- "Trying to load a saved file with " + str(len(layer_names)) + " layers into a model with " +
- str(len(network.all_layers)) + " layers."
- )
-
- # check mismatch form network weights to hdf5
- for name in net_index.keys():
- if name not in layer_names:
- logging.warning("Network layer named '%s' not found in loaded hdf5 file. It will be skipped." % name)
-
- # load weights from hdf5 to network
- _load_weights_from_hdf5_group(f, network.all_layers, skip)
-
- f.close()
- logging.info("[*] Load %s SUCCESS!" % filepath)
-
-
- def load_and_assign_ckpt(model_dir, network=None, skip=True):
- """Load weights by name from a given file of ckpt format
-
- Parameters
- ----------
- model_dir : str
- The directory from which the ckpt weights will be loaded.
- Examples: model_dir = /root/cnn_model/
- network : Model
- TL model.
- skip : bool
- If True, loaded weights whose names are not found in 'weights' will be skipped. If False,
- an error will be raised when a mismatch is found. Default is True.
-
- Returns
- -------
-
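- Examples
- --------
- A minimal sketch, assuming a TensorFlow backend and a checkpoint directory such as ``/root/cnn_model/`` whose variable names match ``network.all_weights``:
- 
- >>> # find the ckpt in the directory and assign its tensors to the network by name
- >>> tlx.files.load_and_assign_ckpt('/root/cnn_model/', network=network, skip=True)
- 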
- """
- model_path = None
- for root, dirs, files in os.walk(model_dir):
- for file in files:
- filename, extension = os.path.splitext(file)
- if extension in ['.data-00000-of-00001', '.index', '.meta']:
- model_path = model_dir + '/' + filename
- break
- if model_path is None:
- raise Exception('The ckpt file is not found')
-
- reader = pywrap_tensorflow.NewCheckpointReader(model_path)
- var_to_shape_map = reader.get_variable_to_shape_map()
-
- net_weights_name = [w.name for w in network.all_weights]
-
- for key in var_to_shape_map:
- if key not in net_weights_name:
- if skip:
- logging.warning("Weights named '%s' not found in network. Skip it." % key)
- else:
- raise RuntimeError(
- "Weights named '%s' not found in network. Hint: set argument skip=Ture "
- "if you want to skip redundant or mismatch weights." % key
- )
- else:
- assign_tf_variable(network.all_weights[net_weights_name.index(key)], reader.get_tensor(key))
- logging.info("[*] Model restored from ckpt %s" % filename)
-
-
- def ckpt_to_npz_dict(model_dir, save_name='model.npz'):
- """ Save ckpt weights to npz file
-
- Parameters
- ----------
- model_dir : str
- The directory from which the ckpt weights will be loaded.
- Examples: model_dir = /root/cnn_model/
- save_name : str
- The save_name of the `.npz` file.
-
- Returns
- -------
-
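- Examples
- --------
- A minimal sketch, assuming a TensorFlow checkpoint directory such as ``/root/cnn_model/``:
- 
- >>> # dump every tensor stored in the ckpt into one npz file keyed by variable name
- >>> tlx.files.ckpt_to_npz_dict('/root/cnn_model/', save_name='model.npz')
- 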
- """
- model_path = None
- for root, dirs, files in os.walk(model_dir):
- for file in files:
- filename, extension = os.path.splitext(file)
- if extension in ['.data-00000-of-00001', '.index', '.meta']:
- model_path = model_dir + '/' + filename
- break
- if model_path is None:
- raise Exception('The ckpt file is not found')
-
- reader = pywrap_tensorflow.NewCheckpointReader(model_path)
- var_to_shape_map = reader.get_variable_to_shape_map()
-
- parameters_dict = {}
- for key in sorted(var_to_shape_map):
- parameters_dict[key] = reader.get_tensor(key)
- np.savez(save_name, **parameters_dict)
- parameters_dict = None
- del parameters_dict
- logging.info("[*] Ckpt weights saved in npz_dict %s" % save_name)
|