From 6766a8ee3e74ce511c271f070c57ad3846a33a17 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Fri, 27 Mar 2020 10:20:15 -0700 Subject: [PATCH] Make Keras deserialization aware of enable/disable_v2_behavior switches. PiperOrigin-RevId: 303355879 Change-Id: I3b8f7ae7db4ed1185585c6d56ed445cd4c61b6ca --- tensorflow/python/keras/engine/base_layer.py | 4 +- .../python/keras/engine/base_layer_v1.py | 2 +- .../python/keras/initializers/__init__.py | 186 ++++++++++-------- tensorflow/python/keras/initializers_test.py | 8 +- .../keras/layers/convolutional_recurrent.py | 2 +- tensorflow/python/keras/layers/recurrent.py | 2 +- .../python/keras/layers/serialization.py | 175 ++++++++++------ .../python/keras/layers/serialization_test.py | 6 - .../python/keras/utils/generic_utils.py | 10 +- 9 files changed, 235 insertions(+), 160 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 62a94afe51b..a6b168e8ceb 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -538,11 +538,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector): if initializer is None: # If dtype is DT_FLOAT, provide a uniform unit scaling initializer if dtype.is_floating: - initializer = initializers.glorot_uniform() + initializer = initializers.get('glorot_uniform') # If dtype is DT_INT/DT_UINT, provide a default value `zero` # If dtype is DT_BOOL, provide a default value `FALSE` elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool: - initializer = initializers.zeros() + initializer = initializers.get('zeros') # NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here? else: raise ValueError('An initializer for variable %s of type %s is required' diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index 1cf2450edf3..7b4ce8ad54c 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -402,7 +402,7 @@ class Layer(base_layer.Layer): if initializer is None: # If dtype is DT_FLOAT, provide a uniform unit scaling initializer if dtype.is_floating: - initializer = initializers.glorot_uniform() + initializer = initializers.get('glorot_uniform') # If dtype is DT_INT/DT_UINT, provide a default value `zero` # If dtype is DT_BOOL, provide a default value `FALSE` elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool: diff --git a/tensorflow/python/keras/initializers/__init__.py b/tensorflow/python/keras/initializers/__init__.py index 7f2922b4c04..828a5b9ca49 100644 --- a/tensorflow/python/keras/initializers/__init__.py +++ b/tensorflow/python/keras/initializers/__init__.py @@ -14,114 +14,134 @@ # ============================================================================== """Keras initializer serialization / deserialization. """ -# pylint: disable=unused-import -# pylint: disable=line-too-long -# pylint: disable=g-import-not-at-top -# pylint: disable=g-bad-import-order -# pylint: disable=invalid-name from __future__ import absolute_import from __future__ import division from __future__ import print_function +import threading import six from tensorflow.python import tf2 -from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.initializers import initializers_v1 +from tensorflow.python.keras.initializers import initializers_v2 +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.ops import init_ops +from tensorflow.python.util import tf_inspect as inspect from tensorflow.python.util.tf_export import keras_export -# These imports are brought in so that keras.initializers.deserialize -# has them available in module_objects. -from tensorflow.python.keras.initializers.initializers_v2 import Constant as ConstantV2 -from tensorflow.python.keras.initializers.initializers_v2 import GlorotNormal as GlorotNormalV2 -from tensorflow.python.keras.initializers.initializers_v2 import GlorotUniform as GlorotUniformV2 -from tensorflow.python.keras.initializers.initializers_v2 import HeNormal as HeNormalV2 -from tensorflow.python.keras.initializers.initializers_v2 import HeUniform as HeUniformV2 -from tensorflow.python.keras.initializers.initializers_v2 import Identity as IdentityV2 -from tensorflow.python.keras.initializers.initializers_v2 import Initializer -from tensorflow.python.keras.initializers.initializers_v2 import LecunNormal as LecunNormalV2 -from tensorflow.python.keras.initializers.initializers_v2 import LecunUniform as LecunUniformV2 -from tensorflow.python.keras.initializers.initializers_v2 import Ones as OnesV2 -from tensorflow.python.keras.initializers.initializers_v2 import Orthogonal as OrthogonalV2 -from tensorflow.python.keras.initializers.initializers_v2 import RandomNormal as RandomNormalV2 -from tensorflow.python.keras.initializers.initializers_v2 import RandomUniform as RandomUniformV2 -from tensorflow.python.keras.initializers.initializers_v2 import TruncatedNormal as TruncatedNormalV2 -from tensorflow.python.keras.initializers.initializers_v2 import VarianceScaling as VarianceScalingV2 -from tensorflow.python.keras.initializers.initializers_v2 import Zeros as ZerosV2 -if tf2.enabled(): - Constant = ConstantV2 - GlorotNormal = GlorotNormalV2 - GlorotUniform = GlorotUniformV2 - HeNormal = HeNormalV2 - HeUniform = HeUniformV2 - Identity = IdentityV2 - LecunNormal = LecunNormalV2 - LecunUniform = LecunUniformV2 - Ones = OnesV2 - Orthogonal = OrthogonalV2 - RandomNormal = RandomNormalV2 - RandomUniform = RandomUniformV2 - TruncatedNormal = TruncatedNormalV2 - VarianceScaling = VarianceScalingV2 - Zeros = ZerosV2 -else: - from tensorflow.python.ops.init_ops import Constant - from tensorflow.python.ops.init_ops import GlorotNormal - from tensorflow.python.ops.init_ops import GlorotUniform - from tensorflow.python.ops.init_ops import Identity - from tensorflow.python.ops.init_ops import Ones - from tensorflow.python.ops.init_ops import Orthogonal - from tensorflow.python.ops.init_ops import VarianceScaling - from tensorflow.python.ops.init_ops import Zeros - from tensorflow.python.keras.initializers.initializers_v1 import HeNormal - from tensorflow.python.keras.initializers.initializers_v1 import HeUniform - from tensorflow.python.keras.initializers.initializers_v1 import LecunNormal - from tensorflow.python.keras.initializers.initializers_v1 import LecunUniform - from tensorflow.python.keras.initializers.initializers_v1 import RandomNormal - from tensorflow.python.keras.initializers.initializers_v1 import RandomUniform - from tensorflow.python.keras.initializers.initializers_v1 import TruncatedNormal +# LOCAL.ALL_OBJECTS is meant to be a global mutable. Hence we need to make it +# thread-local to avoid concurrent mutations. +LOCAL = threading.local() -# Compatibility aliases -glorot_normal = GlorotNormal -glorot_uniform = GlorotUniform -he_normal = HeNormal -he_uniform = HeUniform -lecun_normal = LecunNormal -lecun_uniform = LecunUniform -zero = zeros = Zeros -one = ones = Ones -constant = Constant -uniform = random_uniform = RandomUniform -normal = random_normal = RandomNormal -truncated_normal = TruncatedNormal -identity = Identity -orthogonal = Orthogonal +def populate_deserializable_objects(): + """Populates dict ALL_OBJECTS with every built-in initializer. + """ + global LOCAL + if not hasattr(LOCAL, 'ALL_OBJECTS'): + LOCAL.ALL_OBJECTS = {} + LOCAL.GENERATED_WITH_V2 = None -# For unit tests -glorot_normalV2 = GlorotNormalV2 -glorot_uniformV2 = GlorotUniformV2 -he_normalV2 = HeNormalV2 -he_uniformV2 = HeUniformV2 -lecun_normalV2 = LecunNormalV2 -lecun_uniformV2 = LecunUniformV2 + if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled(): + # Objects dict is already generated for the proper TF version: + # do nothing. + return + + LOCAL.ALL_OBJECTS = {} + LOCAL.GENERATED_WITH_V2 = tf2.enabled() + + # Compatibility aliases (need to exist in both V1 and V2). + LOCAL.ALL_OBJECTS['ConstantV2'] = initializers_v2.Constant + LOCAL.ALL_OBJECTS['GlorotNormalV2'] = initializers_v2.GlorotNormal + LOCAL.ALL_OBJECTS['GlorotUniformV2'] = initializers_v2.GlorotUniform + LOCAL.ALL_OBJECTS['HeNormalV2'] = initializers_v2.HeNormal + LOCAL.ALL_OBJECTS['HeUniformV2'] = initializers_v2.HeUniform + LOCAL.ALL_OBJECTS['IdentityV2'] = initializers_v2.Identity + LOCAL.ALL_OBJECTS['LecunNormalV2'] = initializers_v2.LecunNormal + LOCAL.ALL_OBJECTS['LecunUniformV2'] = initializers_v2.LecunUniform + LOCAL.ALL_OBJECTS['OnesV2'] = initializers_v2.Ones + LOCAL.ALL_OBJECTS['OrthogonalV2'] = initializers_v2.Orthogonal + LOCAL.ALL_OBJECTS['RandomNormalV2'] = initializers_v2.RandomNormal + LOCAL.ALL_OBJECTS['RandomUniformV2'] = initializers_v2.RandomUniform + LOCAL.ALL_OBJECTS['TruncatedNormalV2'] = initializers_v2.TruncatedNormal + LOCAL.ALL_OBJECTS['VarianceScalingV2'] = initializers_v2.VarianceScaling + LOCAL.ALL_OBJECTS['ZerosV2'] = initializers_v2.Zeros + + # Out of an abundance of caution we also include these aliases that have + # a non-zero probability of having been included in saved configs in the past. + LOCAL.ALL_OBJECTS['glorot_normalV2'] = initializers_v2.GlorotNormal + LOCAL.ALL_OBJECTS['glorot_uniformV2'] = initializers_v2.GlorotUniform + LOCAL.ALL_OBJECTS['he_normalV2'] = initializers_v2.HeNormal + LOCAL.ALL_OBJECTS['he_uniformV2'] = initializers_v2.HeUniform + LOCAL.ALL_OBJECTS['lecun_normalV2'] = initializers_v2.LecunNormal + LOCAL.ALL_OBJECTS['lecun_uniformV2'] = initializers_v2.LecunUniform + + if tf2.enabled(): + # For V2, entries are generated automatically based on the content of + # initializers_v2.py. + v2_objs = {} + base_cls = initializers_v2.Initializer + generic_utils.populate_dict_with_module_objects( + v2_objs, + [initializers_v2], + obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls)) + for key, value in v2_objs.items(): + LOCAL.ALL_OBJECTS[key] = value + # Functional aliases. + LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value + else: + # V1 initializers. + v1_objs = { + 'Constant': init_ops.Constant, + 'GlorotNormal': init_ops.GlorotNormal, + 'GlorotUniform': init_ops.GlorotUniform, + 'Identity': init_ops.Identity, + 'Ones': init_ops.Ones, + 'Orthogonal': init_ops.Orthogonal, + 'VarianceScaling': init_ops.VarianceScaling, + 'Zeros': init_ops.Zeros, + 'HeNormal': initializers_v1.HeNormal, + 'HeUniform': initializers_v1.HeUniform, + 'LecunNormal': initializers_v1.LecunNormal, + 'LecunUniform': initializers_v1.LecunUniform, + 'RandomNormal': initializers_v1.RandomNormal, + 'RandomUniform': initializers_v1.RandomUniform, + 'TruncatedNormal': initializers_v1.TruncatedNormal, + } + for key, value in v1_objs.items(): + LOCAL.ALL_OBJECTS[key] = value + # Functional aliases. + LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value + + # More compatibility aliases. + LOCAL.ALL_OBJECTS['normal'] = LOCAL.ALL_OBJECTS['random_normal'] + LOCAL.ALL_OBJECTS['uniform'] = LOCAL.ALL_OBJECTS['random_uniform'] + LOCAL.ALL_OBJECTS['one'] = LOCAL.ALL_OBJECTS['ones'] + LOCAL.ALL_OBJECTS['zero'] = LOCAL.ALL_OBJECTS['zeros'] + + +# For backwards compatibility, we populate this file with the objects +# from ALL_OBJECTS. We make no guarantees as to whether these objects will +# using their correct version. +populate_deserializable_objects() +globals().update(LOCAL.ALL_OBJECTS) # Utility functions @keras_export('keras.initializers.serialize') def serialize(initializer): - return serialize_keras_object(initializer) + return generic_utils.serialize_keras_object(initializer) @keras_export('keras.initializers.deserialize') def deserialize(config, custom_objects=None): """Return an `Initializer` object from its config.""" - module_objects = globals() - return deserialize_keras_object( + populate_deserializable_objects() + return generic_utils.deserialize_keras_object( config, - module_objects=module_objects, + module_objects=LOCAL.ALL_OBJECTS, custom_objects=custom_objects, printable_module_name='initializer') diff --git a/tensorflow/python/keras/initializers_test.py b/tensorflow/python/keras/initializers_test.py index 1170432d8e1..3e4502f14fc 100644 --- a/tensorflow/python/keras/initializers_test.py +++ b/tensorflow/python/keras/initializers_test.py @@ -91,7 +91,7 @@ class KerasInitializersTest(test.TestCase): fan_in, _ = init_ops._compute_fans(tensor_shape) std = np.sqrt(1. / fan_in) self._runner( - initializers.lecun_uniformV2(seed=123), + initializers.LecunUniformV2(seed=123), tensor_shape, target_mean=0., target_std=std) @@ -113,7 +113,7 @@ class KerasInitializersTest(test.TestCase): fan_in, _ = init_ops._compute_fans(tensor_shape) std = np.sqrt(2. / fan_in) self._runner( - initializers.he_uniformV2(seed=123), + initializers.HeUniformV2(seed=123), tensor_shape, target_mean=0., target_std=std) @@ -124,7 +124,7 @@ class KerasInitializersTest(test.TestCase): fan_in, _ = init_ops._compute_fans(tensor_shape) std = np.sqrt(1. / fan_in) self._runner( - initializers.lecun_normalV2(seed=123), + initializers.LecunNormalV2(seed=123), tensor_shape, target_mean=0., target_std=std) @@ -146,7 +146,7 @@ class KerasInitializersTest(test.TestCase): fan_in, _ = init_ops._compute_fans(tensor_shape) std = np.sqrt(2. / fan_in) self._runner( - initializers.he_normalV2(seed=123), + initializers.HeNormalV2(seed=123), tensor_shape, target_mean=0., target_std=std) diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py index 48f724b55e1..1929b145561 100644 --- a/tensorflow/python/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/layers/convolutional_recurrent.py @@ -569,7 +569,7 @@ class ConvLSTM2DCell(DropoutRNNCellMixin, Layer): def bias_initializer(_, *args, **kwargs): return K.concatenate([ self.bias_initializer((self.filters,), *args, **kwargs), - initializers.Ones()((self.filters,), *args, **kwargs), + initializers.get('ones')((self.filters,), *args, **kwargs), self.bias_initializer((self.filters * 2,), *args, **kwargs), ]) else: diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index ec635590e8b..628ecc332c5 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -2340,7 +2340,7 @@ class LSTMCell(DropoutRNNCellMixin, Layer): def bias_initializer(_, *args, **kwargs): return K.concatenate([ self.bias_initializer((self.units,), *args, **kwargs), - initializers.Ones()((self.units,), *args, **kwargs), + initializers.get('ones')((self.units,), *args, **kwargs), self.bias_initializer((self.units * 2,), *args, **kwargs), ]) else: diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py index afefcc3f040..a73f7744d0e 100644 --- a/tensorflow/python/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -21,51 +21,125 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import threading + from tensorflow.python import tf2 -from tensorflow.python.keras.engine.base_layer import AddLoss -from tensorflow.python.keras.engine.base_layer import AddMetric -from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer -from tensorflow.python.keras.engine.input_layer import Input -from tensorflow.python.keras.engine.input_layer import InputLayer -from tensorflow.python.keras.layers.advanced_activations import * -from tensorflow.python.keras.layers.convolutional import * -from tensorflow.python.keras.layers.convolutional_recurrent import * -from tensorflow.python.keras.layers.core import * -from tensorflow.python.keras.layers.cudnn_recurrent import * -from tensorflow.python.keras.layers.dense_attention import * -from tensorflow.python.keras.layers.embeddings import * -from tensorflow.python.keras.layers.local import * -from tensorflow.python.keras.layers.merge import * -from tensorflow.python.keras.layers.noise import * -from tensorflow.python.keras.layers.normalization import * -from tensorflow.python.keras.layers.pooling import * -from tensorflow.python.keras.layers.preprocessing.image_preprocessing import * -from tensorflow.python.keras.layers.preprocessing.normalization_v1 import * -from tensorflow.python.keras.layers.recurrent import * -from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import * -from tensorflow.python.keras.layers.wrappers import * -from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import input_layer +from tensorflow.python.keras.layers import advanced_activations +from tensorflow.python.keras.layers import convolutional +from tensorflow.python.keras.layers import convolutional_recurrent +from tensorflow.python.keras.layers import core +from tensorflow.python.keras.layers import cudnn_recurrent +from tensorflow.python.keras.layers import dense_attention +from tensorflow.python.keras.layers import embeddings +from tensorflow.python.keras.layers import local +from tensorflow.python.keras.layers import merge +from tensorflow.python.keras.layers import noise +from tensorflow.python.keras.layers import normalization +from tensorflow.python.keras.layers import normalization_v2 +from tensorflow.python.keras.layers import pooling +from tensorflow.python.keras.layers import recurrent +from tensorflow.python.keras.layers import recurrent_v2 +from tensorflow.python.keras.layers import rnn_cell_wrapper_v2 +from tensorflow.python.keras.layers import wrappers +from tensorflow.python.keras.layers.preprocessing import image_preprocessing +from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization +from tensorflow.python.keras.layers.preprocessing import normalization_v1 as preprocessing_normalization_v1 +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.util import tf_inspect as inspect from tensorflow.python.util.tf_export import keras_export -if tf2.enabled(): - from tensorflow.python.keras.layers.normalization_v2 import * # pylint: disable=g-import-not-at-top - from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top - from tensorflow.python.keras.layers.preprocessing.normalization import * # pylint: disable=g-import-not-at-top -# This deserialization table is added for backward compatibility, as in TF 1.13, -# BatchNormalizationV1 and BatchNormalizationV2 are used as class name for v1 -# and v2 version of BatchNormalization, respectively. Here we explicitly convert -# them to the canonical name in the config of deserialization. -_DESERIALIZATION_TABLE = { - 'BatchNormalizationV1': 'BatchNormalization', - 'BatchNormalizationV2': 'BatchNormalization', -} +ALL_MODULES = ( + base_layer, + input_layer, + advanced_activations, + convolutional, + convolutional_recurrent, + core, + cudnn_recurrent, + dense_attention, + embeddings, + local, + merge, + noise, + normalization, + pooling, + image_preprocessing, + preprocessing_normalization_v1, + recurrent, + wrappers +) +ALL_V2_MODULES = ( + rnn_cell_wrapper_v2, + normalization_v2, + recurrent_v2, + preprocessing_normalization +) + +# ALL_OBJECTS is meant to be a global mutable. Hence we need to make it +# thread-local to avoid concurrent mutations. +LOCAL = threading.local() + + +def populate_deserializable_objects(): + """Populates dict ALL_OBJECTS with every built-in layer. + """ + global LOCAL + if not hasattr(LOCAL, 'ALL_OBJECTS'): + LOCAL.ALL_OBJECTS = {} + LOCAL.GENERATED_WITH_V2 = None + + if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled(): + # Objects dict is already generated for the proper TF version: + # do nothing. + return + + LOCAL.ALL_OBJECTS = {} + LOCAL.GENERATED_WITH_V2 = tf2.enabled() + + base_cls = base_layer.Layer + generic_utils.populate_dict_with_module_objects( + LOCAL.ALL_OBJECTS, + ALL_MODULES, + obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls)) + + # Overwrite certain V1 objects with V2 versions + if tf2.enabled(): + generic_utils.populate_dict_with_module_objects( + LOCAL.ALL_OBJECTS, + ALL_V2_MODULES, + obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls)) + + # These deserialization aliases are added for backward compatibility, + # as in TF 1.13, "BatchNormalizationV1" and "BatchNormalizationV2" + # were used as class name for v1 and v2 version of BatchNormalization, + # respectively. Here we explicitly convert them to their canonical names. + LOCAL.ALL_OBJECTS['BatchNormalizationV1'] = normalization.BatchNormalization + LOCAL.ALL_OBJECTS[ + 'BatchNormalizationV2'] = normalization_v2.BatchNormalization + + # Prevent circular dependencies. + from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.premade.linear import LinearModel # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.premade.wide_deep import WideDeepModel # pylint: disable=g-import-not-at-top + from tensorflow.python.feature_column import dense_features # pylint: disable=g-import-not-at-top + from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top + + LOCAL.ALL_OBJECTS['Input'] = input_layer.Input + LOCAL.ALL_OBJECTS['Network'] = models.Network + LOCAL.ALL_OBJECTS['Model'] = models.Model + LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential + LOCAL.ALL_OBJECTS['LinearModel'] = LinearModel + LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel + LOCAL.ALL_OBJECTS['DenseFeatures'] = dense_features.DenseFeatures + LOCAL.ALL_OBJECTS['SequenceFeatures'] = sfc.SequenceFeatures @keras_export('keras.layers.serialize') def serialize(layer): - return serialize_keras_object(layer) + return generic_utils.serialize_keras_object(layer) @keras_export('keras.layers.deserialize') @@ -80,30 +154,9 @@ def deserialize(config, custom_objects=None): Returns: Layer instance (may be Model, Sequential, Network, Layer...) """ - # Prevent circular dependencies. - from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top - from tensorflow.python.keras.premade.linear import LinearModel # pylint: disable=g-import-not-at-top - from tensorflow.python.keras.premade.wide_deep import WideDeepModel # pylint: disable=g-import-not-at-top - from tensorflow.python.feature_column import dense_features # pylint: disable=g-import-not-at-top - from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top - - globs = globals() # All layers. - globs['Network'] = models.Network - globs['Model'] = models.Model - globs['Sequential'] = models.Sequential - globs['LinearModel'] = LinearModel - globs['WideDeepModel'] = WideDeepModel - - # Prevent circular dependencies with FeatureColumn serialization. - globs['DenseFeatures'] = dense_features.DenseFeatures - globs['SequenceFeatures'] = sfc.SequenceFeatures - - layer_class_name = config['class_name'] - if layer_class_name in _DESERIALIZATION_TABLE: - config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name] - - return deserialize_keras_object( + populate_deserializable_objects() + return generic_utils.deserialize_keras_object( config, - module_objects=globs, + module_objects=LOCAL.ALL_OBJECTS, custom_objects=custom_objects, printable_module_name='layer') diff --git a/tensorflow/python/keras/layers/serialization_test.py b/tensorflow/python/keras/layers/serialization_test.py index 5c23937ddb4..cd88b072224 100644 --- a/tensorflow/python/keras/layers/serialization_test.py +++ b/tensorflow/python/keras/layers/serialization_test.py @@ -124,12 +124,6 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase): layer = batchnorm_layer( momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2') config = keras.layers.serialize(layer) - # To simulate if BatchNormalizationV1 or BatchNormalizationV2 appears in the - # saved model. - if batchnorm_layer is batchnorm_v1.BatchNormalization: - config['class_name'] = 'BatchNormalizationV1' - else: - config['class_name'] = 'BatchNormalizationV2' new_layer = keras.layers.deserialize(config) self.assertEqual(new_layer.momentum, 0.9) if tf2.enabled(): diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py index 5331806fc23..897f97c793b 100644 --- a/tensorflow/python/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -364,7 +364,8 @@ def deserialize_keras_object(identifier, else: obj = module_objects.get(object_name) if obj is None: - raise ValueError('Unknown ' + printable_module_name + ':' + object_name) + raise ValueError( + 'Unknown ' + printable_module_name + ': ' + object_name) # Classes passed by name are instantiated with no args, functions are # returned as-is. if tf_inspect.isclass(obj): @@ -783,6 +784,13 @@ def is_default(method): return getattr(method, '_is_default', False) +def populate_dict_with_module_objects(target_dict, modules, obj_filter): + for module in modules: + for name in dir(module): + obj = getattr(module, name) + if obj_filter(obj): + target_dict[name] = obj + # Aliases