Make Keras deserialization aware of enable/disable_v2_behavior switches.

PiperOrigin-RevId: 303355879
Change-Id: I3b8f7ae7db4ed1185585c6d56ed445cd4c61b6ca
Author:     Francois Chollet
Committed:  TensorFlower Gardener
Date:       2020-03-27 10:20:15 -07:00
Commit:     6766a8ee3e
Parent:     cd59b293b5

9 changed files with 235 additions and 160 deletions


@@ -538,11 +538,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     if initializer is None:
       # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
       if dtype.is_floating:
-        initializer = initializers.glorot_uniform()
+        initializer = initializers.get('glorot_uniform')
       # If dtype is DT_INT/DT_UINT, provide a default value `zero`
       # If dtype is DT_BOOL, provide a default value `FALSE`
       elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
-        initializer = initializers.zeros()
+        initializer = initializers.get('zeros')
       # NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here?
       else:
         raise ValueError('An initializer for variable %s of type %s is required'
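Note: `initializers.get` is the public Keras resolver; it accepts a string name, a config dict, or an initializer instance, and the string/dict paths are routed through `deserialize()`, which this commit makes sensitive to the v1/v2 behavior switch. A minimal usage sketch (illustrative, not part of the diff; the same substitution is repeated in the v1 base layer below):

from tensorflow.python.keras import initializers

# All three forms resolve to an equivalent initializer; the string and dict
# forms go through deserialize(), so they now pick the class that matches
# the current tf2.enabled() state.
a = initializers.get('glorot_uniform')
b = initializers.get({'class_name': 'GlorotUniform', 'config': {}})
c = initializers.get(initializers.GlorotUniform())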


@@ -402,7 +402,7 @@ class Layer(base_layer.Layer):
     if initializer is None:
       # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
       if dtype.is_floating:
-        initializer = initializers.glorot_uniform()
+        initializer = initializers.get('glorot_uniform')
       # If dtype is DT_INT/DT_UINT, provide a default value `zero`
       # If dtype is DT_BOOL, provide a default value `FALSE`
       elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:


@@ -14,114 +14,134 @@
 # ==============================================================================
 """Keras initializer serialization / deserialization.
 """
-# pylint: disable=unused-import
-# pylint: disable=line-too-long
-# pylint: disable=g-import-not-at-top
-# pylint: disable=g-bad-import-order
-# pylint: disable=invalid-name
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import threading
+
 import six

 from tensorflow.python import tf2
-from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
-from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.keras.initializers import initializers_v1
+from tensorflow.python.keras.initializers import initializers_v2
+from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.ops import init_ops
+from tensorflow.python.util import tf_inspect as inspect
 from tensorflow.python.util.tf_export import keras_export

-# These imports are brought in so that keras.initializers.deserialize
-# has them available in module_objects.
-from tensorflow.python.keras.initializers.initializers_v2 import Constant as ConstantV2
-from tensorflow.python.keras.initializers.initializers_v2 import GlorotNormal as GlorotNormalV2
-from tensorflow.python.keras.initializers.initializers_v2 import GlorotUniform as GlorotUniformV2
-from tensorflow.python.keras.initializers.initializers_v2 import HeNormal as HeNormalV2
-from tensorflow.python.keras.initializers.initializers_v2 import HeUniform as HeUniformV2
-from tensorflow.python.keras.initializers.initializers_v2 import Identity as IdentityV2
-from tensorflow.python.keras.initializers.initializers_v2 import Initializer
-from tensorflow.python.keras.initializers.initializers_v2 import LecunNormal as LecunNormalV2
-from tensorflow.python.keras.initializers.initializers_v2 import LecunUniform as LecunUniformV2
-from tensorflow.python.keras.initializers.initializers_v2 import Ones as OnesV2
-from tensorflow.python.keras.initializers.initializers_v2 import Orthogonal as OrthogonalV2
-from tensorflow.python.keras.initializers.initializers_v2 import RandomNormal as RandomNormalV2
-from tensorflow.python.keras.initializers.initializers_v2 import RandomUniform as RandomUniformV2
-from tensorflow.python.keras.initializers.initializers_v2 import TruncatedNormal as TruncatedNormalV2
-from tensorflow.python.keras.initializers.initializers_v2 import VarianceScaling as VarianceScalingV2
-from tensorflow.python.keras.initializers.initializers_v2 import Zeros as ZerosV2
-
-if tf2.enabled():
-  Constant = ConstantV2
-  GlorotNormal = GlorotNormalV2
-  GlorotUniform = GlorotUniformV2
-  HeNormal = HeNormalV2
-  HeUniform = HeUniformV2
-  Identity = IdentityV2
-  LecunNormal = LecunNormalV2
-  LecunUniform = LecunUniformV2
-  Ones = OnesV2
-  Orthogonal = OrthogonalV2
-  RandomNormal = RandomNormalV2
-  RandomUniform = RandomUniformV2
-  TruncatedNormal = TruncatedNormalV2
-  VarianceScaling = VarianceScalingV2
-  Zeros = ZerosV2
-else:
-  from tensorflow.python.ops.init_ops import Constant
-  from tensorflow.python.ops.init_ops import GlorotNormal
-  from tensorflow.python.ops.init_ops import GlorotUniform
-  from tensorflow.python.ops.init_ops import Identity
-  from tensorflow.python.ops.init_ops import Ones
-  from tensorflow.python.ops.init_ops import Orthogonal
-  from tensorflow.python.ops.init_ops import VarianceScaling
-  from tensorflow.python.ops.init_ops import Zeros
-  from tensorflow.python.keras.initializers.initializers_v1 import HeNormal
-  from tensorflow.python.keras.initializers.initializers_v1 import HeUniform
-  from tensorflow.python.keras.initializers.initializers_v1 import LecunNormal
-  from tensorflow.python.keras.initializers.initializers_v1 import LecunUniform
-  from tensorflow.python.keras.initializers.initializers_v1 import RandomNormal
-  from tensorflow.python.keras.initializers.initializers_v1 import RandomUniform
-  from tensorflow.python.keras.initializers.initializers_v1 import TruncatedNormal
-
-# Compatibility aliases
-glorot_normal = GlorotNormal
-glorot_uniform = GlorotUniform
-he_normal = HeNormal
-he_uniform = HeUniform
-lecun_normal = LecunNormal
-lecun_uniform = LecunUniform
-zero = zeros = Zeros
-one = ones = Ones
-constant = Constant
-uniform = random_uniform = RandomUniform
-normal = random_normal = RandomNormal
-truncated_normal = TruncatedNormal
-identity = Identity
-orthogonal = Orthogonal
-
-# For unit tests
-glorot_normalV2 = GlorotNormalV2
-glorot_uniformV2 = GlorotUniformV2
-he_normalV2 = HeNormalV2
-he_uniformV2 = HeUniformV2
-lecun_normalV2 = LecunNormalV2
-lecun_uniformV2 = LecunUniformV2
+# LOCAL.ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
+# thread-local to avoid concurrent mutations.
+LOCAL = threading.local()
+
+
+def populate_deserializable_objects():
+  """Populates dict ALL_OBJECTS with every built-in initializer.
+  """
+  global LOCAL
+  if not hasattr(LOCAL, 'ALL_OBJECTS'):
+    LOCAL.ALL_OBJECTS = {}
+    LOCAL.GENERATED_WITH_V2 = None
+
+  if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
+    # Objects dict is already generated for the proper TF version:
+    # do nothing.
+    return
+
+  LOCAL.ALL_OBJECTS = {}
+  LOCAL.GENERATED_WITH_V2 = tf2.enabled()
+
+  # Compatibility aliases (need to exist in both V1 and V2).
+  LOCAL.ALL_OBJECTS['ConstantV2'] = initializers_v2.Constant
+  LOCAL.ALL_OBJECTS['GlorotNormalV2'] = initializers_v2.GlorotNormal
+  LOCAL.ALL_OBJECTS['GlorotUniformV2'] = initializers_v2.GlorotUniform
+  LOCAL.ALL_OBJECTS['HeNormalV2'] = initializers_v2.HeNormal
+  LOCAL.ALL_OBJECTS['HeUniformV2'] = initializers_v2.HeUniform
+  LOCAL.ALL_OBJECTS['IdentityV2'] = initializers_v2.Identity
+  LOCAL.ALL_OBJECTS['LecunNormalV2'] = initializers_v2.LecunNormal
+  LOCAL.ALL_OBJECTS['LecunUniformV2'] = initializers_v2.LecunUniform
+  LOCAL.ALL_OBJECTS['OnesV2'] = initializers_v2.Ones
+  LOCAL.ALL_OBJECTS['OrthogonalV2'] = initializers_v2.Orthogonal
+  LOCAL.ALL_OBJECTS['RandomNormalV2'] = initializers_v2.RandomNormal
+  LOCAL.ALL_OBJECTS['RandomUniformV2'] = initializers_v2.RandomUniform
+  LOCAL.ALL_OBJECTS['TruncatedNormalV2'] = initializers_v2.TruncatedNormal
+  LOCAL.ALL_OBJECTS['VarianceScalingV2'] = initializers_v2.VarianceScaling
+  LOCAL.ALL_OBJECTS['ZerosV2'] = initializers_v2.Zeros
+
+  # Out of an abundance of caution we also include these aliases that have
+  # a non-zero probability of having been included in saved configs in the past.
+  LOCAL.ALL_OBJECTS['glorot_normalV2'] = initializers_v2.GlorotNormal
+  LOCAL.ALL_OBJECTS['glorot_uniformV2'] = initializers_v2.GlorotUniform
+  LOCAL.ALL_OBJECTS['he_normalV2'] = initializers_v2.HeNormal
+  LOCAL.ALL_OBJECTS['he_uniformV2'] = initializers_v2.HeUniform
+  LOCAL.ALL_OBJECTS['lecun_normalV2'] = initializers_v2.LecunNormal
+  LOCAL.ALL_OBJECTS['lecun_uniformV2'] = initializers_v2.LecunUniform
+
+  if tf2.enabled():
+    # For V2, entries are generated automatically based on the content of
+    # initializers_v2.py.
+    v2_objs = {}
+    base_cls = initializers_v2.Initializer
+    generic_utils.populate_dict_with_module_objects(
+        v2_objs,
+        [initializers_v2],
+        obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
+    for key, value in v2_objs.items():
+      LOCAL.ALL_OBJECTS[key] = value
+      # Functional aliases.
+      LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value
+  else:
+    # V1 initializers.
+    v1_objs = {
+        'Constant': init_ops.Constant,
+        'GlorotNormal': init_ops.GlorotNormal,
+        'GlorotUniform': init_ops.GlorotUniform,
+        'Identity': init_ops.Identity,
+        'Ones': init_ops.Ones,
+        'Orthogonal': init_ops.Orthogonal,
+        'VarianceScaling': init_ops.VarianceScaling,
+        'Zeros': init_ops.Zeros,
+        'HeNormal': initializers_v1.HeNormal,
+        'HeUniform': initializers_v1.HeUniform,
+        'LecunNormal': initializers_v1.LecunNormal,
+        'LecunUniform': initializers_v1.LecunUniform,
+        'RandomNormal': initializers_v1.RandomNormal,
+        'RandomUniform': initializers_v1.RandomUniform,
+        'TruncatedNormal': initializers_v1.TruncatedNormal,
+    }
+    for key, value in v1_objs.items():
+      LOCAL.ALL_OBJECTS[key] = value
+      # Functional aliases.
+      LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value
+
+  # More compatibility aliases.
+  LOCAL.ALL_OBJECTS['normal'] = LOCAL.ALL_OBJECTS['random_normal']
+  LOCAL.ALL_OBJECTS['uniform'] = LOCAL.ALL_OBJECTS['random_uniform']
+  LOCAL.ALL_OBJECTS['one'] = LOCAL.ALL_OBJECTS['ones']
+  LOCAL.ALL_OBJECTS['zero'] = LOCAL.ALL_OBJECTS['zeros']
+
+
+# For backwards compatibility, we populate this file with the objects
+# from ALL_OBJECTS. We make no guarantees as to whether these objects will
+# be using their correct version.
+populate_deserializable_objects()
+globals().update(LOCAL.ALL_OBJECTS)

 # Utility functions


 @keras_export('keras.initializers.serialize')
 def serialize(initializer):
-  return serialize_keras_object(initializer)
+  return generic_utils.serialize_keras_object(initializer)


 @keras_export('keras.initializers.deserialize')
 def deserialize(config, custom_objects=None):
   """Return an `Initializer` object from its config."""
-  module_objects = globals()
-  return deserialize_keras_object(
+  populate_deserializable_objects()
+  return generic_utils.deserialize_keras_object(
       config,
-      module_objects=module_objects,
+      module_objects=LOCAL.ALL_OBJECTS,
       custom_objects=custom_objects,
       printable_module_name='initializer')
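The effect of this rewrite is easiest to see at the call site: the lookup table is regenerated whenever the recorded GENERATED_WITH_V2 flag disagrees with the current tf2.enabled() state, so flipping the behavior switch changes what a name deserializes to. A hedged sketch of the intended behavior (usage example, not part of the diff):

from tensorflow.python.compat import v2_compat
from tensorflow.python.keras import initializers

v2_compat.disable_v2_behavior()  # tf2.enabled() is now False

# populate_deserializable_objects() notices the mismatch between
# LOCAL.GENERATED_WITH_V2 and tf2.enabled() and rebuilds LOCAL.ALL_OBJECTS,
# so 'GlorotUniform' resolves to the V1 (init_ops) class here, even though
# the module was first imported with v2 behavior on.
init = initializers.deserialize(
    {'class_name': 'GlorotUniform', 'config': {'seed': 1}})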


@@ -91,7 +91,7 @@ class KerasInitializersTest(test.TestCase):
     fan_in, _ = init_ops._compute_fans(tensor_shape)
     std = np.sqrt(1. / fan_in)
     self._runner(
-        initializers.lecun_uniformV2(seed=123),
+        initializers.LecunUniformV2(seed=123),
         tensor_shape,
         target_mean=0.,
         target_std=std)

@@ -113,7 +113,7 @@ class KerasInitializersTest(test.TestCase):
     fan_in, _ = init_ops._compute_fans(tensor_shape)
     std = np.sqrt(2. / fan_in)
     self._runner(
-        initializers.he_uniformV2(seed=123),
+        initializers.HeUniformV2(seed=123),
         tensor_shape,
         target_mean=0.,
         target_std=std)

@@ -124,7 +124,7 @@ class KerasInitializersTest(test.TestCase):
     fan_in, _ = init_ops._compute_fans(tensor_shape)
     std = np.sqrt(1. / fan_in)
     self._runner(
-        initializers.lecun_normalV2(seed=123),
+        initializers.LecunNormalV2(seed=123),
         tensor_shape,
         target_mean=0.,
         target_std=std)

@@ -146,7 +146,7 @@ class KerasInitializersTest(test.TestCase):
     fan_in, _ = init_ops._compute_fans(tensor_shape)
     std = np.sqrt(2. / fan_in)
     self._runner(
-        initializers.he_normalV2(seed=123),
+        initializers.HeNormalV2(seed=123),
         tensor_shape,
         target_mean=0.,
         target_std=std)
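Note on the rename: the lowercase *V2 spellings stay registered in ALL_OBJECTS ("out of an abundance of caution", per the comment in the diff above) and are re-exported by globals().update(LOCAL.ALL_OBJECTS), so the tests are simply moving to the canonical CamelCase class names. Both spellings should still deserialize to the same class; a quick equivalence check, assuming v2 behavior:

from tensorflow.python.keras import initializers

a = initializers.deserialize({'class_name': 'LecunUniformV2', 'config': {}})
b = initializers.deserialize({'class_name': 'lecun_uniformV2', 'config': {}})
assert type(a) is type(b)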


@@ -569,7 +569,7 @@ class ConvLSTM2DCell(DropoutRNNCellMixin, Layer):
        def bias_initializer(_, *args, **kwargs):
          return K.concatenate([
              self.bias_initializer((self.filters,), *args, **kwargs),
-             initializers.Ones()((self.filters,), *args, **kwargs),
+             initializers.get('ones')((self.filters,), *args, **kwargs),
              self.bias_initializer((self.filters * 2,), *args, **kwargs),
          ])
      else:


@@ -2340,7 +2340,7 @@ class LSTMCell(DropoutRNNCellMixin, Layer):
       def bias_initializer(_, *args, **kwargs):
         return K.concatenate([
             self.bias_initializer((self.units,), *args, **kwargs),
-            initializers.Ones()((self.units,), *args, **kwargs),
+            initializers.get('ones')((self.units,), *args, **kwargs),
             self.bias_initializer((self.units * 2,), *args, **kwargs),
         ])
     else:
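Both RNN hunks make the same substitution inside the unit-forget-bias trick: the forget-gate bias is fetched by name at call time via initializers.get('ones') rather than binding a version-specific Ones class at import time. A simplified, self-contained sketch of the pattern (hypothetical helper name; the real cells build the per-gate biases inline as shown above):

from tensorflow.python.keras import backend as K
from tensorflow.python.keras import initializers

def unit_forget_bias_initializer(units, bias_initializer):
  # Forget-gate bias starts at 1 (helps gradient flow early in training);
  # the remaining gates use the user-provided initializer.
  def bias_init(_, *args, **kwargs):
    return K.concatenate([
        bias_initializer((units,), *args, **kwargs),            # input gate
        initializers.get('ones')((units,), *args, **kwargs),    # forget gate
        bias_initializer((units * 2,), *args, **kwargs),        # cell, output
    ])
  return bias_init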


@@ -21,51 +21,125 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import threading
+
 from tensorflow.python import tf2
-from tensorflow.python.keras.engine.base_layer import AddLoss
-from tensorflow.python.keras.engine.base_layer import AddMetric
-from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
-from tensorflow.python.keras.engine.input_layer import Input
-from tensorflow.python.keras.engine.input_layer import InputLayer
-from tensorflow.python.keras.layers.advanced_activations import *
-from tensorflow.python.keras.layers.convolutional import *
-from tensorflow.python.keras.layers.convolutional_recurrent import *
-from tensorflow.python.keras.layers.core import *
-from tensorflow.python.keras.layers.cudnn_recurrent import *
-from tensorflow.python.keras.layers.dense_attention import *
-from tensorflow.python.keras.layers.embeddings import *
-from tensorflow.python.keras.layers.local import *
-from tensorflow.python.keras.layers.merge import *
-from tensorflow.python.keras.layers.noise import *
-from tensorflow.python.keras.layers.normalization import *
-from tensorflow.python.keras.layers.pooling import *
-from tensorflow.python.keras.layers.preprocessing.image_preprocessing import *
-from tensorflow.python.keras.layers.preprocessing.normalization_v1 import *
-from tensorflow.python.keras.layers.recurrent import *
-from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import *
-from tensorflow.python.keras.layers.wrappers import *
-from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
-from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.engine import input_layer
+from tensorflow.python.keras.layers import advanced_activations
+from tensorflow.python.keras.layers import convolutional
+from tensorflow.python.keras.layers import convolutional_recurrent
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.layers import cudnn_recurrent
+from tensorflow.python.keras.layers import dense_attention
+from tensorflow.python.keras.layers import embeddings
+from tensorflow.python.keras.layers import local
+from tensorflow.python.keras.layers import merge
+from tensorflow.python.keras.layers import noise
+from tensorflow.python.keras.layers import normalization
+from tensorflow.python.keras.layers import normalization_v2
+from tensorflow.python.keras.layers import pooling
+from tensorflow.python.keras.layers import recurrent
+from tensorflow.python.keras.layers import recurrent_v2
+from tensorflow.python.keras.layers import rnn_cell_wrapper_v2
+from tensorflow.python.keras.layers import wrappers
+from tensorflow.python.keras.layers.preprocessing import image_preprocessing
+from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization
+from tensorflow.python.keras.layers.preprocessing import normalization_v1 as preprocessing_normalization_v1
+from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.util import tf_inspect as inspect
 from tensorflow.python.util.tf_export import keras_export

-if tf2.enabled():
-  from tensorflow.python.keras.layers.normalization_v2 import *  # pylint: disable=g-import-not-at-top
-  from tensorflow.python.keras.layers.recurrent_v2 import *  # pylint: disable=g-import-not-at-top
-  from tensorflow.python.keras.layers.preprocessing.normalization import *  # pylint: disable=g-import-not-at-top
-
-# This deserialization table is added for backward compatibility, as in TF 1.13,
-# BatchNormalizationV1 and BatchNormalizationV2 are used as class name for v1
-# and v2 version of BatchNormalization, respectively. Here we explicitly convert
-# them to the canonical name in the config of deserialization.
-_DESERIALIZATION_TABLE = {
-    'BatchNormalizationV1': 'BatchNormalization',
-    'BatchNormalizationV2': 'BatchNormalization',
-}
+ALL_MODULES = (
+    base_layer,
+    input_layer,
+    advanced_activations,
+    convolutional,
+    convolutional_recurrent,
+    core,
+    cudnn_recurrent,
+    dense_attention,
+    embeddings,
+    local,
+    merge,
+    noise,
+    normalization,
+    pooling,
+    image_preprocessing,
+    preprocessing_normalization_v1,
+    recurrent,
+    wrappers
+)
+
+ALL_V2_MODULES = (
+    rnn_cell_wrapper_v2,
+    normalization_v2,
+    recurrent_v2,
+    preprocessing_normalization
+)
+
+# ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
+# thread-local to avoid concurrent mutations.
+LOCAL = threading.local()
+
+
+def populate_deserializable_objects():
+  """Populates dict ALL_OBJECTS with every built-in layer.
+  """
+  global LOCAL
+  if not hasattr(LOCAL, 'ALL_OBJECTS'):
+    LOCAL.ALL_OBJECTS = {}
+    LOCAL.GENERATED_WITH_V2 = None
+
+  if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
+    # Objects dict is already generated for the proper TF version:
+    # do nothing.
+    return
+
+  LOCAL.ALL_OBJECTS = {}
+  LOCAL.GENERATED_WITH_V2 = tf2.enabled()
+
+  base_cls = base_layer.Layer
+  generic_utils.populate_dict_with_module_objects(
+      LOCAL.ALL_OBJECTS,
+      ALL_MODULES,
+      obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
+
+  # Overwrite certain V1 objects with V2 versions
+  if tf2.enabled():
+    generic_utils.populate_dict_with_module_objects(
+        LOCAL.ALL_OBJECTS,
+        ALL_V2_MODULES,
+        obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
+
+  # These deserialization aliases are added for backward compatibility,
+  # as in TF 1.13, "BatchNormalizationV1" and "BatchNormalizationV2"
+  # were used as class name for v1 and v2 version of BatchNormalization,
+  # respectively. Here we explicitly convert them to their canonical names.
+  LOCAL.ALL_OBJECTS['BatchNormalizationV1'] = normalization.BatchNormalization
+  LOCAL.ALL_OBJECTS[
+      'BatchNormalizationV2'] = normalization_v2.BatchNormalization
+
+  # Prevent circular dependencies.
+  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
+  from tensorflow.python.keras.premade.linear import LinearModel  # pylint: disable=g-import-not-at-top
+  from tensorflow.python.keras.premade.wide_deep import WideDeepModel  # pylint: disable=g-import-not-at-top
+  from tensorflow.python.feature_column import dense_features  # pylint: disable=g-import-not-at-top
+  from tensorflow.python.feature_column import sequence_feature_column as sfc  # pylint: disable=g-import-not-at-top
+
+  LOCAL.ALL_OBJECTS['Input'] = input_layer.Input
+  LOCAL.ALL_OBJECTS['Network'] = models.Network
+  LOCAL.ALL_OBJECTS['Model'] = models.Model
+  LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential
+  LOCAL.ALL_OBJECTS['LinearModel'] = LinearModel
+  LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel
+  LOCAL.ALL_OBJECTS['DenseFeatures'] = dense_features.DenseFeatures
+  LOCAL.ALL_OBJECTS['SequenceFeatures'] = sfc.SequenceFeatures


 @keras_export('keras.layers.serialize')
 def serialize(layer):
-  return serialize_keras_object(layer)
+  return generic_utils.serialize_keras_object(layer)


 @keras_export('keras.layers.deserialize')
@@ -80,30 +154,9 @@ def deserialize(config, custom_objects=None):
   Returns:
       Layer instance (may be Model, Sequential, Network, Layer...)
   """
-  # Prevent circular dependencies.
-  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
-  from tensorflow.python.keras.premade.linear import LinearModel  # pylint: disable=g-import-not-at-top
-  from tensorflow.python.keras.premade.wide_deep import WideDeepModel  # pylint: disable=g-import-not-at-top
-  from tensorflow.python.feature_column import dense_features  # pylint: disable=g-import-not-at-top
-  from tensorflow.python.feature_column import sequence_feature_column as sfc  # pylint: disable=g-import-not-at-top
-
-  globs = globals()  # All layers.
-  globs['Network'] = models.Network
-  globs['Model'] = models.Model
-  globs['Sequential'] = models.Sequential
-  globs['LinearModel'] = LinearModel
-  globs['WideDeepModel'] = WideDeepModel
-
-  # Prevent circular dependencies with FeatureColumn serialization.
-  globs['DenseFeatures'] = dense_features.DenseFeatures
-  globs['SequenceFeatures'] = sfc.SequenceFeatures
-
-  layer_class_name = config['class_name']
-  if layer_class_name in _DESERIALIZATION_TABLE:
-    config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name]
-
-  return deserialize_keras_object(
+  populate_deserializable_objects()
+  return generic_utils.deserialize_keras_object(
       config,
-      module_objects=globs,
+      module_objects=LOCAL.ALL_OBJECTS,
       custom_objects=custom_objects,
      printable_module_name='layer')
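On the layer side, the same thread-local pattern replaces both the `import *` fan-in and the old _DESERIALIZATION_TABLE rewrite: the TF 1.13 legacy names become ordinary table entries, and deserialize rebuilds the table whenever the behavior mode has changed since the last call. Illustrative usage (mirroring the test below, not part of the diff):

from tensorflow.python import keras

# Legacy class names now resolve directly, with no config rewriting:
layer = keras.layers.deserialize(
    {'class_name': 'BatchNormalizationV2', 'config': {'momentum': 0.9}})

Because LOCAL is a threading.local, each thread builds and caches its own table, so a behavior-mode switch in one thread cannot race with a deserialization in another.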


@@ -124,12 +124,6 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase):
     layer = batchnorm_layer(
         momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2')
     config = keras.layers.serialize(layer)
-    # To simulate if BatchNormalizationV1 or BatchNormalizationV2 appears in the
-    # saved model.
-    if batchnorm_layer is batchnorm_v1.BatchNormalization:
-      config['class_name'] = 'BatchNormalizationV1'
-    else:
-      config['class_name'] = 'BatchNormalizationV2'
     new_layer = keras.layers.deserialize(config)
     self.assertEqual(new_layer.momentum, 0.9)
     if tf2.enabled():
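(The deleted simulation block is no longer needed: with "BatchNormalizationV1" and "BatchNormalizationV2" registered directly in ALL_OBJECTS by the serialization.py change above, legacy spellings deserialize without any class_name rewriting, so the round-trip test can exercise the serialized config as-is.)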


@@ -364,7 +364,8 @@ def deserialize_keras_object(identifier,
   else:
     obj = module_objects.get(object_name)
     if obj is None:
-      raise ValueError('Unknown ' + printable_module_name + ':' + object_name)
+      raise ValueError(
+          'Unknown ' + printable_module_name + ': ' + object_name)

   # Classes passed by name are instantiated with no args, functions are
   # returned as-is.
   if tf_inspect.isclass(obj):
@@ -783,6 +784,13 @@ def is_default(method):
   return getattr(method, '_is_default', False)


+def populate_dict_with_module_objects(target_dict, modules, obj_filter):
+  for module in modules:
+    for name in dir(module):
+      obj = getattr(module, name)
+      if obj_filter(obj):
+        target_dict[name] = obj
+
+
 # Aliases
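The new helper is deliberately generic: it walks each module's namespace and keeps whatever passes obj_filter. A small, self-contained usage sketch (hypothetical module, not TF code):

from types import ModuleType

from tensorflow.python.keras.utils.generic_utils import populate_dict_with_module_objects

# Build a fake module with two classes and one non-class attribute.
fake = ModuleType('fake')
class Foo(object):
  pass
class Bar(object):
  pass
fake.Foo, fake.Bar, fake.CONSTANT = Foo, Bar, 42

registry = {}
populate_dict_with_module_objects(
    registry, [fake], obj_filter=lambda x: isinstance(x, type))
assert registry == {'Foo': Foo, 'Bar': Bar}  # CONSTANT is filtered out

Both registries in this commit are built this way, with obj_filter restricting entries to subclasses of the relevant base class (base_layer.Layer for layers, initializers_v2.Initializer for initializers).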