233 lines
7.3 KiB
Python
233 lines
7.3 KiB
Python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tests for Keras initializers."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.keras import backend
|
|
from tensorflow.python.keras import combinations
|
|
from tensorflow.python.keras import initializers
|
|
from tensorflow.python.keras import models
|
|
from tensorflow.python.keras.engine import input_layer
|
|
from tensorflow.python.keras.layers import core
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import init_ops
|
|
from tensorflow.python.platform import test
|
|
|
|
|
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
|
class KerasInitializersTest(test.TestCase):
|
|
|
|
def _runner(self, init, shape, target_mean=None, target_std=None,
|
|
target_max=None, target_min=None):
|
|
variable = backend.variable(init(shape))
|
|
output = backend.get_value(variable)
|
|
# Test serialization (assumes deterministic behavior).
|
|
config = init.get_config()
|
|
reconstructed_init = init.__class__.from_config(config)
|
|
variable = backend.variable(reconstructed_init(shape))
|
|
output_2 = backend.get_value(variable)
|
|
self.assertAllClose(output, output_2, atol=1e-4)
|
|
|
|
def test_uniform(self):
|
|
tensor_shape = (9, 6, 7)
|
|
with self.cached_session():
|
|
self._runner(
|
|
initializers.RandomUniformV2(minval=-1, maxval=1, seed=124),
|
|
tensor_shape,
|
|
target_mean=0.,
|
|
target_max=1,
|
|
target_min=-1)
|
|
|
|
def test_normal(self):
|
|
tensor_shape = (8, 12, 99)
|
|
with self.cached_session():
|
|
self._runner(
|
|
initializers.RandomNormalV2(mean=0, stddev=1, seed=153),
|
|
tensor_shape,
|
|
target_mean=0.,
|
|
target_std=1)
|
|
|
|
def test_truncated_normal(self):
|
|
tensor_shape = (12, 99, 7)
|
|
with self.cached_session():
|
|
self._runner(
|
|
initializers.TruncatedNormalV2(mean=0, stddev=1, seed=126),
|
|
tensor_shape,
|
|
target_mean=0.,
|
|
target_max=2,
|
|
target_min=-2)
|
|
|
|
def test_constant(self):
|
|
tensor_shape = (5, 6, 4)
|
|
with self.cached_session():
|
|
self._runner(
|
|
initializers.ConstantV2(2.),
|
|
tensor_shape,
|
|
target_mean=2,
|
|
target_max=2,
|
|
target_min=2)
|
|
|
|
def test_lecun_uniform(self):
|
|
tensor_shape = (5, 6, 4, 2)
|
|
with self.cached_session():
|
|
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
|
std = np.sqrt(1. / fan_in)
|
|
self._runner(
|
|
initializers.LecunUniformV2(seed=123),
|
|
tensor_shape,
|
|
target_mean=0.,
|
|
target_std=std)
|
|
|
|
def test_glorot_uniform(self):
|
|
tensor_shape = (5, 6, 4, 2)
|
|
with self.cached_session():
|
|
fan_in, fan_out = init_ops._compute_fans(tensor_shape)
|
|
std = np.sqrt(2. / (fan_in + fan_out))
|
|
self._runner(
|
|
initializers.GlorotUniformV2(seed=123),
|
|
tensor_shape,
|
|
target_mean=0.,
|
|
target_std=std)
|
|
|
|
def test_he_uniform(self):
|
|
tensor_shape = (5, 6, 4, 2)
|
|
with self.cached_session():
|
|
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
|
std = np.sqrt(2. / fan_in)
|
|
self._runner(
|
|
initializers.HeUniformV2(seed=123),
|
|
tensor_shape,
|
|
target_mean=0.,
|
|
target_std=std)
|
|
|
|
def test_lecun_normal(self):
|
|
tensor_shape = (5, 6, 4, 2)
|
|
with self.cached_session():
|
|
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
|
std = np.sqrt(1. / fan_in)
|
|
self._runner(
|
|
initializers.LecunNormalV2(seed=123),
|
|
tensor_shape,
|
|
target_mean=0.,
|
|
target_std=std)
|
|
|
|
def test_glorot_normal(self):
|
|
tensor_shape = (5, 6, 4, 2)
|
|
with self.cached_session():
|
|
fan_in, fan_out = init_ops._compute_fans(tensor_shape)
|
|
std = np.sqrt(2. / (fan_in + fan_out))
|
|
self._runner(
|
|
initializers.GlorotNormalV2(seed=123),
|
|
tensor_shape,
|
|
target_mean=0.,
|
|
target_std=std)
|
|
|
|
def test_he_normal(self):
|
|
tensor_shape = (5, 6, 4, 2)
|
|
with self.cached_session():
|
|
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
|
std = np.sqrt(2. / fan_in)
|
|
self._runner(
|
|
initializers.HeNormalV2(seed=123),
|
|
tensor_shape,
|
|
target_mean=0.,
|
|
target_std=std)
|
|
|
|
def test_orthogonal(self):
|
|
tensor_shape = (20, 20)
|
|
with self.cached_session():
|
|
self._runner(
|
|
initializers.OrthogonalV2(seed=123), tensor_shape, target_mean=0.)
|
|
|
|
def test_identity(self):
|
|
with self.cached_session():
|
|
tensor_shape = (3, 4, 5)
|
|
with self.assertRaises(ValueError):
|
|
self._runner(
|
|
initializers.IdentityV2(),
|
|
tensor_shape,
|
|
target_mean=1. / tensor_shape[0],
|
|
target_max=1.)
|
|
|
|
tensor_shape = (3, 3)
|
|
self._runner(
|
|
initializers.IdentityV2(),
|
|
tensor_shape,
|
|
target_mean=1. / tensor_shape[0],
|
|
target_max=1.)
|
|
|
|
def test_zero(self):
|
|
tensor_shape = (4, 5)
|
|
with self.cached_session():
|
|
self._runner(
|
|
initializers.ZerosV2(), tensor_shape, target_mean=0., target_max=0.)
|
|
|
|
def test_one(self):
|
|
tensor_shape = (4, 5)
|
|
with self.cached_session():
|
|
self._runner(
|
|
initializers.OnesV2(), tensor_shape, target_mean=1., target_max=1.)
|
|
|
|
def test_default_random_uniform(self):
|
|
ru = initializers.get('uniform')
|
|
self.assertEqual(ru.minval, -0.05)
|
|
self.assertEqual(ru.maxval, 0.05)
|
|
|
|
def test_default_random_normal(self):
|
|
rn = initializers.get('normal')
|
|
self.assertEqual(rn.mean, 0.0)
|
|
self.assertEqual(rn.stddev, 0.05)
|
|
|
|
def test_default_truncated_normal(self):
|
|
tn = initializers.get('truncated_normal')
|
|
self.assertEqual(tn.mean, 0.0)
|
|
self.assertEqual(tn.stddev, 0.05)
|
|
|
|
def test_custom_initializer_saving(self):
|
|
|
|
def my_initializer(shape, dtype=None):
|
|
return array_ops.ones(shape, dtype=dtype)
|
|
|
|
inputs = input_layer.Input((10,))
|
|
outputs = core.Dense(1, kernel_initializer=my_initializer)(inputs)
|
|
model = models.Model(inputs, outputs)
|
|
model2 = model.from_config(
|
|
model.get_config(), custom_objects={'my_initializer': my_initializer})
|
|
self.assertEqual(model2.layers[1].kernel_initializer, my_initializer)
|
|
|
|
@test_util.run_v2_only
|
|
def test_load_external_variance_scaling_v2(self):
|
|
external_serialized_json = {
|
|
'class_name': 'VarianceScaling',
|
|
'config': {
|
|
'distribution': 'normal',
|
|
'mode': 'fan_avg',
|
|
'scale': 1.0,
|
|
'seed': None
|
|
}
|
|
}
|
|
initializer = initializers.deserialize(external_serialized_json)
|
|
self.assertEqual(initializer.distribution, 'truncated_normal')
|
|
|
|
|
|
if __name__ == '__main__':
|
|
test.main()
|