Make sure deserialization, serialization, and get all grab the correct
Initializers in V2. PiperOrigin-RevId: 236008740
This commit is contained in:
parent
9650c977e3
commit
4728deaf57
@ -20,9 +20,11 @@ from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||
from tensorflow.python.ops import init_ops_v2
|
||||
|
||||
# These imports are brought in so that keras.initializers.deserialize
|
||||
# has them available in module_objects.
|
||||
@ -160,9 +162,20 @@ def serialize(initializer):
|
||||
|
||||
@keras_export('keras.initializers.deserialize')
|
||||
def deserialize(config, custom_objects=None):
|
||||
"""Return an `Initializer` object from its config."""
|
||||
if tf2.enabled():
|
||||
# Class names are the same for V1 and V2 but the V2 classes
|
||||
# are aliased in this file so we need to grab them directly
|
||||
# from `init_ops_v2`.
|
||||
module_objects = {
|
||||
obj_name: getattr(init_ops_v2, obj_name)
|
||||
for obj_name in dir(init_ops_v2)
|
||||
}
|
||||
else:
|
||||
module_objects = globals()
|
||||
return deserialize_keras_object(
|
||||
config,
|
||||
module_objects=globals(),
|
||||
module_objects=module_objects,
|
||||
custom_objects=custom_objects,
|
||||
printable_module_name='initializer')
|
||||
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -203,6 +204,15 @@ class KerasInitializersTest(test.TestCase):
|
||||
self.assertEqual(tn.mean, 0.0)
|
||||
self.assertEqual(tn.stddev, 0.05)
|
||||
|
||||
def test_initializer_v2_get(self):
|
||||
tf2_force_enabled = tf2._force_enable # pylint: disable=protected-access
|
||||
try:
|
||||
tf2.enable()
|
||||
rn = keras.initializers.get('random_normal')
|
||||
self.assertIn('init_ops_v2', rn.__class__.__module__)
|
||||
finally:
|
||||
tf2._force_enable = tf2_force_enabled # pylint: disable=protected-access
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -34,8 +35,12 @@ class LayerSerializationTest(test.TestCase):
|
||||
self.assertEqual(new_layer.activation, keras.activations.relu)
|
||||
self.assertEqual(new_layer.bias_regularizer.__class__,
|
||||
keras.regularizers.L1L2)
|
||||
self.assertEqual(new_layer.kernel_initializer.__class__,
|
||||
keras.initializers.Ones)
|
||||
if tf2.enabled():
|
||||
self.assertEqual(new_layer.kernel_initializer.__class__,
|
||||
keras.initializers.OnesV2)
|
||||
else:
|
||||
self.assertEqual(new_layer.kernel_initializer.__class__,
|
||||
keras.initializers.Ones)
|
||||
self.assertEqual(new_layer.units, 3)
|
||||
|
||||
def test_serialize_deserialize_batchnorm(self):
|
||||
@ -45,8 +50,12 @@ class LayerSerializationTest(test.TestCase):
|
||||
self.assertEqual(config['class_name'], 'BatchNormalization')
|
||||
new_layer = keras.layers.deserialize(config)
|
||||
self.assertEqual(new_layer.momentum, 0.9)
|
||||
self.assertEqual(new_layer.beta_initializer.__class__,
|
||||
keras.initializers.Zeros)
|
||||
if tf2.enabled():
|
||||
self.assertEqual(new_layer.beta_initializer.__class__,
|
||||
keras.initializers.ZerosV2)
|
||||
else:
|
||||
self.assertEqual(new_layer.beta_initializer.__class__,
|
||||
keras.initializers.Zeros)
|
||||
self.assertEqual(new_layer.gamma_regularizer.__class__,
|
||||
keras.regularizers.L1L2)
|
||||
|
||||
|
@ -655,13 +655,14 @@ class RNNTest(test.TestCase):
|
||||
save.restore(sess, save_path)
|
||||
self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias))
|
||||
|
||||
# TODO(scottzhu): Look into updating for V2 Intializers.
|
||||
@test_util.run_deprecated_v1
|
||||
def testRNNCellSerialization(self):
|
||||
for cell in [
|
||||
rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True),
|
||||
rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32),
|
||||
rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32),
|
||||
rnn_cell_impl.GRUCell(
|
||||
32, kernel_initializer="ones", dtype=dtypes.float32)
|
||||
rnn_cell_impl.GRUCell(32, dtype=dtypes.float32)
|
||||
]:
|
||||
with self.cached_session():
|
||||
x = keras.Input((None, 5))
|
||||
|
@ -804,3 +804,17 @@ class _RandomGenerator(object):
|
||||
op = random_ops.truncated_normal
|
||||
return op(
|
||||
shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=self.seed)
|
||||
|
||||
# Compatibility aliases
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
zero = zeros = Zeros
|
||||
one = ones = Ones
|
||||
constant = Constant
|
||||
uniform = random_uniform = RandomUniform
|
||||
normal = random_normal = RandomNormal
|
||||
truncated_normal = TruncatedNormal
|
||||
identity = Identity
|
||||
orthogonal = Orthogonal
|
||||
glorot_normal = GlorotNormal
|
||||
glorot_uniform = GlorotUniform
|
||||
|
Loading…
Reference in New Issue
Block a user