|
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- import os
- import unittest
-
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-
- import numpy as np
- import tensorflow as tf
- import tensorlayer as tl
- from tensorlayer.layers import *
- from tensorlayer.models import *
-
- from tests.utils import CustomTestCase
-
-
def basic_static_model(include_top=True):
    """Build a small conv net by wiring layers from an ``Input`` to a ``Model``.

    The network is two (Conv2d 5x5 / 16 filters / ReLU -> MaxPool2d 3x3, stride 2)
    stages, a Flatten, and a 100-unit Dense layer ("dense1").

    Parameters
    ----------
    include_top : bool
        When exactly ``True``, a final 10-unit Dense layer ("dense2") is appended.

    Returns
    -------
    Model
        ``Model(inputs=<Input of shape (None, 24, 24, 3)>, outputs=<last layer output>)``.
    """
    inputs = Input((None, 24, 24, 3))

    features = inputs
    # Both stages share the same hyper-parameters; only the layer names differ.
    for conv_name, pool_name in (("conv1", 'pool1'), ("conv2", 'pool2')):
        features = Conv2d(16, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, name=conv_name)(features)
        features = MaxPool2d((3, 3), (2, 2), padding='SAME', name=pool_name)(features)

    features = Flatten(name='flatten')(features)
    outputs = Dense(100, act=None, name="dense1")(features)
    if include_top is True:
        outputs = Dense(10, act=None, name="dense2")(outputs)

    return Model(inputs=inputs, outputs=outputs)
-
-
class basic_dynamic_model(Model):
    """Small conv net defined by subclassing ``Model`` with an explicit ``forward``.

    Layers: two (Conv2d 5x5 / 16 filters / ReLU -> MaxPool2d 3x3, stride 2) stages,
    a Flatten, a 100-unit Dense ("dense1") and, optionally, a 10-unit Dense ("dense2").
    Every weighted layer is given an explicit ``in_channels``.

    Parameters
    ----------
    include_top : bool
        When exactly ``True``, the final "dense2" layer is created and applied.
        Any other value disables it (same rule as ``basic_static_model``).
    """

    def __init__(self, include_top=True):
        super(basic_dynamic_model, self).__init__()
        # Resolve the flag once so that layer creation here and layer usage in
        # ``forward`` can never disagree. Previously ``__init__`` tested
        # ``is True`` while ``forward`` tested truthiness, so a truthy non-bool
        # (e.g. ``1``) skipped creating ``dense2`` but still tried to call it.
        self.include_top = include_top is True

        self.conv1 = Conv2d(16, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, in_channels=3, name="conv1")
        self.pool1 = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1')

        self.conv2 = Conv2d(16, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, in_channels=16, name="conv2")
        self.pool2 = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2')

        self.flatten = Flatten(name='flatten')
        # NOTE(review): 576 presumably equals 6 * 6 * 16, i.e. a 24x24 input halved
        # twice by the pools -- confirm against the inputs callers actually feed.
        self.dense1 = Dense(100, act=None, in_channels=576, name="dense1")
        if self.include_top:
            self.dense2 = Dense(10, act=None, in_channels=100, name="dense2")

    def forward(self, x):
        """Apply the layers in order and return the output of the last one applied."""
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.dense1(x)
        if self.include_top:
            x = self.dense2(x)
        return x
-
-
class Nested_VGG(Model):
    """Model whose only attributes are two ``tl.models.vgg16()`` instances.

    ``forward`` does nothing and returns ``None``.
    """

    def __init__(self):
        super(Nested_VGG, self).__init__()
        # Right-hand side is evaluated left to right, then assigned in order,
        # so vgg1 is still built and attached before vgg2.
        self.vgg1, self.vgg2 = tl.models.vgg16(), tl.models.vgg16()

    def forward(self, x):
        return None
-
-
class Model_Save_Test(CustomTestCase):
    """Save/load round-trip tests for static, dynamic, nested and LayerList models.

    Most tests follow the same recipe: save the weights, overwrite one weight
    with zeros, load the file back, and check the weight matches its original
    value to within 1e-7.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared static/dynamic models, with and without the top layer."""
        cls.static_basic = basic_static_model()
        cls.dynamic_basic = basic_dynamic_model()
        cls.static_basic_skip = basic_static_model(include_top=False)
        cls.dynamic_basic_skip = basic_dynamic_model(include_top=False)

        print([layer.name for layer in cls.dynamic_basic.all_layers])
        print([layer.name for layer in cls.dynamic_basic_skip.all_layers])

    @classmethod
    def tearDownClass(cls):
        pass

    def _assert_weight_restored(self, expected, weight):
        """Assert that ``weight.numpy()`` matches ``expected`` to within 1e-7 (max abs diff)."""
        self.assertLess(np.max(np.abs(expected - weight.numpy())), 1e-7)

    def normal_save(self, model_basic):
        """Round-trip ``model_basic.all_weights[-2]`` through each save format.

        Covers hdf5 (three load variants), npz and npz_dict, then checks that
        the 'ckpt' format, an unknown format, and a missing file raise the
        expected exception types.

        Parameters
        ----------
        model_basic : Model
            Model whose weights are saved, zeroed and reloaded. Its weights are
            mutated during the test but restored by each load.
        """
        # Default save (no explicit format, unrecognised extension)
        model_basic.save_weights('./model_basic.none')

        # hdf5
        print('testing hdf5 saving...')
        modify_val = np.zeros_like(model_basic.all_weights[-2].numpy())
        ori_val = model_basic.all_weights[-2].numpy()
        model_basic.save_weights("./model_basic.h5")
        model_basic.all_weights[-2].assign(modify_val)
        model_basic.load_weights("./model_basic.h5")
        self._assert_weight_restored(ori_val, model_basic.all_weights[-2])

        model_basic.all_weights[-2].assign(modify_val)
        model_basic.load_weights("./model_basic.h5", format="hdf5")
        self._assert_weight_restored(ori_val, model_basic.all_weights[-2])

        model_basic.all_weights[-2].assign(modify_val)
        model_basic.load_weights("./model_basic.h5", format="hdf5", in_order=False)
        self._assert_weight_restored(ori_val, model_basic.all_weights[-2])

        # npz
        print('testing npz saving...')
        model_basic.save_weights("./model_basic.npz", format='npz')
        model_basic.all_weights[-2].assign(modify_val)
        model_basic.load_weights("./model_basic.npz")
        # Previously this load (format inferred from the extension) was never checked.
        self._assert_weight_restored(ori_val, model_basic.all_weights[-2])

        model_basic.all_weights[-2].assign(modify_val)
        model_basic.load_weights("./model_basic.npz", format='npz')
        model_basic.save_weights("./model_basic.npz")
        self._assert_weight_restored(ori_val, model_basic.all_weights[-2])

        # npz_dict
        print('testing npz_dict saving...')
        model_basic.save_weights("./model_basic.npz", format='npz_dict')
        model_basic.all_weights[-2].assign(modify_val)
        model_basic.load_weights("./model_basic.npz", format='npz_dict')
        self._assert_weight_restored(ori_val, model_basic.all_weights[-2])

        # The checks below use assertRaises so that they fail when *no*
        # exception is raised; the former try/except + assertIsInstance form
        # passed silently in that case.

        # ckpt
        with self.assertRaises(NotImplementedError):
            model_basic.save_weights('./model_basic.ckpt', format='ckpt')

        # other cases
        with self.assertRaises(ValueError):
            model_basic.save_weights('./model_basic.xyz', format='xyz')
        with self.assertRaises(FileNotFoundError):
            model_basic.load_weights('./model_basic.xyz', format='xyz')
        with self.assertRaises(ValueError):
            model_basic.load_weights('./model_basic.h5', format='xyz')

    def test_normal_save(self):
        """Run ``normal_save`` on both basic models, then try a dynamic->static load."""
        print('-' * 20, 'test save weights', '-' * 20)

        self.normal_save(self.static_basic)
        self.normal_save(self.dynamic_basic)

        print('testing save dynamic and load static...')
        # Best-effort probe: any failure is only reported, not asserted.
        try:
            self.dynamic_basic.save_weights("./model_basic.h5")
            self.static_basic.load_weights("./model_basic.h5", in_order=False)
        except Exception as e:
            print(e)

    def test_skip(self):
        """Load a full model's hdf5 file into the top-less model using ``skip=True``."""
        print('-' * 20, 'test skip save/load', '-' * 20)

        print("testing dynamic skip load...")
        self.dynamic_basic.save_weights("./model_basic.h5")
        ori_weights = self.dynamic_basic_skip.all_weights
        ori_val = ori_weights[1].numpy()
        modify_val = np.zeros_like(ori_val)
        self.dynamic_basic_skip.all_weights[1].assign(modify_val)
        self.dynamic_basic_skip.load_weights("./model_basic.h5", skip=True)
        self._assert_weight_restored(ori_val, self.dynamic_basic_skip.all_weights[1])

        # Best-effort probe of the skip=False path: outcome is only printed.
        try:
            self.dynamic_basic_skip.load_weights("./model_basic.h5", in_order=False, skip=False)
        except Exception as e:
            print(e)

        print("testing static skip load...")
        self.static_basic.save_weights("./model_basic.h5")
        ori_weights = self.static_basic_skip.all_weights
        ori_val = ori_weights[1].numpy()
        modify_val = np.zeros_like(ori_val)
        self.static_basic_skip.all_weights[1].assign(modify_val)
        self.static_basic_skip.load_weights("./model_basic.h5", skip=True)
        self._assert_weight_restored(ori_val, self.static_basic_skip.all_weights[1])

        # Best-effort probe of the skip=False path: outcome is only printed.
        try:
            self.static_basic_skip.load_weights("./model_basic.h5", in_order=False, skip=False)
        except Exception as e:
            print(e)

    def test_nested_vgg(self):
        """Round-trip one weight from each inner VGG of a ``Nested_VGG`` through hdf5."""
        print('-' * 20, 'test nested vgg', '-' * 20)
        nested_vgg = Nested_VGG()
        print([layer.name for layer in nested_vgg.all_layers])
        nested_vgg.save_weights("nested_vgg.h5")

        # modify vgg1 weight val
        tar_weight1 = nested_vgg.vgg1.layers[0].all_weights[0]
        print(tar_weight1.name)
        ori_val1 = tar_weight1.numpy()
        tar_weight1.assign(np.zeros_like(ori_val1))
        # modify vgg2 weight val
        tar_weight2 = nested_vgg.vgg2.layers[1].all_weights[0]
        print(tar_weight2.name)
        ori_val2 = tar_weight2.numpy()
        tar_weight2.assign(np.zeros_like(ori_val2))

        nested_vgg.load_weights("nested_vgg.h5")

        self._assert_weight_restored(ori_val1, tar_weight1)
        self._assert_weight_restored(ori_val2, tar_weight2)

    def test_double_nested_vgg(self):
        """Same as ``test_nested_vgg`` with the ``Nested_VGG`` wrapped in another model."""
        print('-' * 20, 'test_double_nested_vgg', '-' * 20)

        class mymodel(Model):
            """Outer model holding a ``Nested_VGG`` and a two-Dense ``LayerList``."""

            def __init__(self):
                super(mymodel, self).__init__()
                self.inner = Nested_VGG()
                self.list = LayerList(
                    [
                        tl.layers.Dense(n_units=4, in_channels=10, name='dense1'),
                        tl.layers.Dense(n_units=3, in_channels=4, name='dense2')
                    ]
                )

            def forward(self, *inputs, **kwargs):
                pass

        net = mymodel()
        net.save_weights("double_nested.h5")
        print([x.name for x in net.all_layers])

        # modify vgg1 weight val
        tar_weight1 = net.inner.vgg1.layers[0].all_weights[0]
        ori_val1 = tar_weight1.numpy()
        tar_weight1.assign(np.zeros_like(ori_val1))
        # modify vgg2 weight val
        tar_weight2 = net.inner.vgg2.layers[1].all_weights[0]
        ori_val2 = tar_weight2.numpy()
        tar_weight2.assign(np.zeros_like(ori_val2))

        net.load_weights("double_nested.h5")
        self._assert_weight_restored(ori_val1, tar_weight1)
        self._assert_weight_restored(ori_val2, tar_weight2)

    def test_layerlist(self):
        """Round-trip a weight of a model whose output comes from a ``LayerList``."""
        print('-' * 20, 'test_layerlist', '-' * 20)

        # simple modellayer
        ni = tl.layers.Input([10, 4])
        nn = tl.layers.Dense(n_units=3, name='dense1')(ni)
        modellayer = tl.models.Model(inputs=ni, outputs=nn, name='modellayer').as_layer()

        # nested layerlist with modellayer
        inputs = tl.layers.Input([10, 5])
        layer1 = tl.layers.LayerList([tl.layers.Dense(n_units=4, name='dense1'), modellayer])(inputs)
        model = tl.models.Model(inputs=inputs, outputs=layer1, name='layerlistmodel')

        model.save_weights("layerlist.h5")
        tar_weight = model.get_layer(index=-1)[0].all_weights[0]
        print(tar_weight.name)
        ori_val = tar_weight.numpy()
        tar_weight.assign(np.zeros_like(ori_val))

        model.load_weights("layerlist.h5")
        self._assert_weight_restored(ori_val, tar_weight)

    def test_exceptions(self):
        """Try saving a model built directly from an ``Input``; any error is only printed."""
        print('-' * 20, 'test_exceptions', '-' * 20)
        # Best-effort probe: whatever happens is reported, not asserted.
        try:
            ni = Input([4, 784])
            model = Model(inputs=ni, outputs=ni)
            model.save_weights('./empty_model.h5')
        except Exception as e:
            print(e)
-
-
# Script entry point: when this file is executed directly, hand control to
# unittest's command-line runner.
if __name__ == '__main__':

    unittest.main()
|