Make sure deserialization, serialization, and get all grab the correct

Initializers in V2.

PiperOrigin-RevId: 236008740
This commit is contained in:
Thomas O'Malley 2019-02-27 15:49:28 -08:00 committed by TensorFlower Gardener
parent 9650c977e3
commit 4728deaf57
5 changed files with 54 additions and 7 deletions

View File

@ -20,9 +20,11 @@ from __future__ import print_function
import six
from tensorflow.python import tf2
from tensorflow.python.framework import dtypes
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import init_ops_v2
# These imports are brought in so that keras.initializers.deserialize
# has them available in module_objects.
@ -160,9 +162,20 @@ def serialize(initializer):
@keras_export('keras.initializers.deserialize')
def deserialize(config, custom_objects=None):
"""Return an `Initializer` object from its config."""
if tf2.enabled():
# Class names are the same for V1 and V2 but the V2 classes
# are aliased in this file so we need to grab them directly
# from `init_ops_v2`.
module_objects = {
obj_name: getattr(init_ops_v2, obj_name)
for obj_name in dir(init_ops_v2)
}
else:
module_objects = globals()
return deserialize_keras_object(
config,
module_objects=globals(),
module_objects=module_objects,
custom_objects=custom_objects,
printable_module_name='initializer')

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.framework import test_util
from tensorflow.python.ops import init_ops
from tensorflow.python.platform import test
@ -203,6 +204,15 @@ class KerasInitializersTest(test.TestCase):
self.assertEqual(tn.mean, 0.0)
self.assertEqual(tn.stddev, 0.05)
def test_initializer_v2_get(self):
tf2_force_enabled = tf2._force_enable # pylint: disable=protected-access
try:
tf2.enable()
rn = keras.initializers.get('random_normal')
self.assertIn('init_ops_v2', rn.__class__.__module__)
finally:
tf2._force_enable = tf2_force_enabled # pylint: disable=protected-access
if __name__ == '__main__':
test.main()

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.platform import test
@ -34,8 +35,12 @@ class LayerSerializationTest(test.TestCase):
self.assertEqual(new_layer.activation, keras.activations.relu)
self.assertEqual(new_layer.bias_regularizer.__class__,
keras.regularizers.L1L2)
self.assertEqual(new_layer.kernel_initializer.__class__,
keras.initializers.Ones)
if tf2.enabled():
self.assertEqual(new_layer.kernel_initializer.__class__,
keras.initializers.OnesV2)
else:
self.assertEqual(new_layer.kernel_initializer.__class__,
keras.initializers.Ones)
self.assertEqual(new_layer.units, 3)
def test_serialize_deserialize_batchnorm(self):
@ -45,8 +50,12 @@ class LayerSerializationTest(test.TestCase):
self.assertEqual(config['class_name'], 'BatchNormalization')
new_layer = keras.layers.deserialize(config)
self.assertEqual(new_layer.momentum, 0.9)
self.assertEqual(new_layer.beta_initializer.__class__,
keras.initializers.Zeros)
if tf2.enabled():
self.assertEqual(new_layer.beta_initializer.__class__,
keras.initializers.ZerosV2)
else:
self.assertEqual(new_layer.beta_initializer.__class__,
keras.initializers.Zeros)
self.assertEqual(new_layer.gamma_regularizer.__class__,
keras.regularizers.L1L2)

View File

@ -655,13 +655,14 @@ class RNNTest(test.TestCase):
save.restore(sess, save_path)
self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias))
# TODO(scottzhu): Look into updating for V2 Intializers.
@test_util.run_deprecated_v1
def testRNNCellSerialization(self):
for cell in [
rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True),
rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32),
rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32),
rnn_cell_impl.GRUCell(
32, kernel_initializer="ones", dtype=dtypes.float32)
rnn_cell_impl.GRUCell(32, dtype=dtypes.float32)
]:
with self.cached_session():
x = keras.Input((None, 5))

View File

@ -804,3 +804,17 @@ class _RandomGenerator(object):
op = random_ops.truncated_normal
return op(
shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=self.seed)
# Compatibility aliases
# pylint: disable=invalid-name
zero = zeros = Zeros
one = ones = Ones
constant = Constant
uniform = random_uniform = RandomUniform
normal = random_normal = RandomNormal
truncated_normal = TruncatedNormal
identity = Identity
orthogonal = Orthogonal
glorot_normal = GlorotNormal
glorot_uniform = GlorotUniform