Make Keras deserialization aware of enable/disable_v2_behavior switches.

PiperOrigin-RevId: 303355879
Change-Id: I3b8f7ae7db4ed1185585c6d56ed445cd4c61b6ca
Francois Chollet 2020-03-27 10:20:15 -07:00 committed by TensorFlower Gardener
parent cd59b293b5
commit 6766a8ee3e
9 changed files with 235 additions and 160 deletions
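In short: before this change, keras.initializers and keras.layers froze their deserialization tables at import time based on tf2.enabled(), so calling enable_v2_behavior() or disable_v2_behavior() afterwards had no effect on which classes a config resolved to. The diffs below rebuild those tables lazily whenever the V2 flag changes. A rough sketch of the intended behavior (not part of the diff; assumes TF 2.x):

import tensorflow as tf

# With V2 behavior on (the TF 2.x default), configs resolve to V2 classes.
init = tf.keras.initializers.deserialize(
    {'class_name': 'GlorotUniform', 'config': {}})

# A program that calls tf.compat.v1.disable_v2_behavior() at startup should
# now get V1 objects from the same call, because the lookup table is
# regenerated whenever tf2.enabled() flips.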

View File

@@ -538,11 +538,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
if initializer is None:
# If dtype is DT_FLOAT, provide a uniform unit scaling initializer
if dtype.is_floating:
initializer = initializers.glorot_uniform()
initializer = initializers.get('glorot_uniform')
# If dtype is DT_INT/DT_UINT, provide a default value `zero`
# If dtype is DT_BOOL, provide a default value `FALSE`
elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
initializer = initializers.zeros()
initializer = initializers.get('zeros')
# NOTES: Do we need to support handling DT_STRING and DT_COMPLEX here?
else:
raise ValueError('An initializer for variable %s of type %s is required'

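Both base_layer hunks (this one and the next) route the dtype-based defaults through initializers.get(), so the default class is resolved at call time via the version-aware table rather than pinned at import. A standalone sketch of the selection logic (hypothetical helper; dtype stands in for a tf.DType):

from tensorflow.keras import initializers

def default_initializer(dtype):
    # Floating-point variables default to Glorot uniform scaling.
    if dtype.is_floating:
        return initializers.get('glorot_uniform')
    # Integer, unsigned, and bool variables default to zeros.
    if dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
        return initializers.get('zeros')
    raise ValueError('An initializer for dtype %s is required.' % dtype)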
View File

@@ -402,7 +402,7 @@ class Layer(base_layer.Layer):
if initializer is None:
# If dtype is DT_FLOAT, provide a uniform unit scaling initializer
if dtype.is_floating:
initializer = initializers.glorot_uniform()
initializer = initializers.get('glorot_uniform')
# If dtype is DT_INT/DT_UINT, provide a default value `zero`
# If dtype is DT_BOOL, provide a default value `FALSE`
elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:

View File

@@ -14,114 +14,134 @@
# ==============================================================================
"""Keras initializer serialization / deserialization.
"""
# pylint: disable=unused-import
# pylint: disable=line-too-long
# pylint: disable=g-import-not-at-top
# pylint: disable=g-bad-import-order
# pylint: disable=invalid-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import six
from tensorflow.python import tf2
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.keras.initializers import initializers_v1
from tensorflow.python.keras.initializers import initializers_v2
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import init_ops
from tensorflow.python.util import tf_inspect as inspect
from tensorflow.python.util.tf_export import keras_export
# These imports are brought in so that keras.initializers.deserialize
# has them available in module_objects.
from tensorflow.python.keras.initializers.initializers_v2 import Constant as ConstantV2
from tensorflow.python.keras.initializers.initializers_v2 import GlorotNormal as GlorotNormalV2
from tensorflow.python.keras.initializers.initializers_v2 import GlorotUniform as GlorotUniformV2
from tensorflow.python.keras.initializers.initializers_v2 import HeNormal as HeNormalV2
from tensorflow.python.keras.initializers.initializers_v2 import HeUniform as HeUniformV2
from tensorflow.python.keras.initializers.initializers_v2 import Identity as IdentityV2
from tensorflow.python.keras.initializers.initializers_v2 import Initializer
from tensorflow.python.keras.initializers.initializers_v2 import LecunNormal as LecunNormalV2
from tensorflow.python.keras.initializers.initializers_v2 import LecunUniform as LecunUniformV2
from tensorflow.python.keras.initializers.initializers_v2 import Ones as OnesV2
from tensorflow.python.keras.initializers.initializers_v2 import Orthogonal as OrthogonalV2
from tensorflow.python.keras.initializers.initializers_v2 import RandomNormal as RandomNormalV2
from tensorflow.python.keras.initializers.initializers_v2 import RandomUniform as RandomUniformV2
from tensorflow.python.keras.initializers.initializers_v2 import TruncatedNormal as TruncatedNormalV2
from tensorflow.python.keras.initializers.initializers_v2 import VarianceScaling as VarianceScalingV2
from tensorflow.python.keras.initializers.initializers_v2 import Zeros as ZerosV2
if tf2.enabled():
Constant = ConstantV2
GlorotNormal = GlorotNormalV2
GlorotUniform = GlorotUniformV2
HeNormal = HeNormalV2
HeUniform = HeUniformV2
Identity = IdentityV2
LecunNormal = LecunNormalV2
LecunUniform = LecunUniformV2
Ones = OnesV2
Orthogonal = OrthogonalV2
RandomNormal = RandomNormalV2
RandomUniform = RandomUniformV2
TruncatedNormal = TruncatedNormalV2
VarianceScaling = VarianceScalingV2
Zeros = ZerosV2
else:
from tensorflow.python.ops.init_ops import Constant
from tensorflow.python.ops.init_ops import GlorotNormal
from tensorflow.python.ops.init_ops import GlorotUniform
from tensorflow.python.ops.init_ops import Identity
from tensorflow.python.ops.init_ops import Ones
from tensorflow.python.ops.init_ops import Orthogonal
from tensorflow.python.ops.init_ops import VarianceScaling
from tensorflow.python.ops.init_ops import Zeros
from tensorflow.python.keras.initializers.initializers_v1 import HeNormal
from tensorflow.python.keras.initializers.initializers_v1 import HeUniform
from tensorflow.python.keras.initializers.initializers_v1 import LecunNormal
from tensorflow.python.keras.initializers.initializers_v1 import LecunUniform
from tensorflow.python.keras.initializers.initializers_v1 import RandomNormal
from tensorflow.python.keras.initializers.initializers_v1 import RandomUniform
from tensorflow.python.keras.initializers.initializers_v1 import TruncatedNormal
# LOCAL.ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
# thread-local to avoid concurrent mutations.
LOCAL = threading.local()
# Compatibility aliases
glorot_normal = GlorotNormal
glorot_uniform = GlorotUniform
he_normal = HeNormal
he_uniform = HeUniform
lecun_normal = LecunNormal
lecun_uniform = LecunUniform
zero = zeros = Zeros
one = ones = Ones
constant = Constant
uniform = random_uniform = RandomUniform
normal = random_normal = RandomNormal
truncated_normal = TruncatedNormal
identity = Identity
orthogonal = Orthogonal
def populate_deserializable_objects():
"""Populates dict ALL_OBJECTS with every built-in initializer.
"""
global LOCAL
if not hasattr(LOCAL, 'ALL_OBJECTS'):
LOCAL.ALL_OBJECTS = {}
LOCAL.GENERATED_WITH_V2 = None
# For unit tests
glorot_normalV2 = GlorotNormalV2
glorot_uniformV2 = GlorotUniformV2
he_normalV2 = HeNormalV2
he_uniformV2 = HeUniformV2
lecun_normalV2 = LecunNormalV2
lecun_uniformV2 = LecunUniformV2
if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
# Objects dict is already generated for the proper TF version:
# do nothing.
return
LOCAL.ALL_OBJECTS = {}
LOCAL.GENERATED_WITH_V2 = tf2.enabled()
# Compatibility aliases (need to exist in both V1 and V2).
LOCAL.ALL_OBJECTS['ConstantV2'] = initializers_v2.Constant
LOCAL.ALL_OBJECTS['GlorotNormalV2'] = initializers_v2.GlorotNormal
LOCAL.ALL_OBJECTS['GlorotUniformV2'] = initializers_v2.GlorotUniform
LOCAL.ALL_OBJECTS['HeNormalV2'] = initializers_v2.HeNormal
LOCAL.ALL_OBJECTS['HeUniformV2'] = initializers_v2.HeUniform
LOCAL.ALL_OBJECTS['IdentityV2'] = initializers_v2.Identity
LOCAL.ALL_OBJECTS['LecunNormalV2'] = initializers_v2.LecunNormal
LOCAL.ALL_OBJECTS['LecunUniformV2'] = initializers_v2.LecunUniform
LOCAL.ALL_OBJECTS['OnesV2'] = initializers_v2.Ones
LOCAL.ALL_OBJECTS['OrthogonalV2'] = initializers_v2.Orthogonal
LOCAL.ALL_OBJECTS['RandomNormalV2'] = initializers_v2.RandomNormal
LOCAL.ALL_OBJECTS['RandomUniformV2'] = initializers_v2.RandomUniform
LOCAL.ALL_OBJECTS['TruncatedNormalV2'] = initializers_v2.TruncatedNormal
LOCAL.ALL_OBJECTS['VarianceScalingV2'] = initializers_v2.VarianceScaling
LOCAL.ALL_OBJECTS['ZerosV2'] = initializers_v2.Zeros
# Out of an abundance of caution we also include these aliases that have
# a non-zero probability of having been included in saved configs in the past.
LOCAL.ALL_OBJECTS['glorot_normalV2'] = initializers_v2.GlorotNormal
LOCAL.ALL_OBJECTS['glorot_uniformV2'] = initializers_v2.GlorotUniform
LOCAL.ALL_OBJECTS['he_normalV2'] = initializers_v2.HeNormal
LOCAL.ALL_OBJECTS['he_uniformV2'] = initializers_v2.HeUniform
LOCAL.ALL_OBJECTS['lecun_normalV2'] = initializers_v2.LecunNormal
LOCAL.ALL_OBJECTS['lecun_uniformV2'] = initializers_v2.LecunUniform
if tf2.enabled():
# For V2, entries are generated automatically based on the content of
# initializers_v2.py.
v2_objs = {}
base_cls = initializers_v2.Initializer
generic_utils.populate_dict_with_module_objects(
v2_objs,
[initializers_v2],
obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
for key, value in v2_objs.items():
LOCAL.ALL_OBJECTS[key] = value
# Functional aliases.
LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value
else:
# V1 initializers.
v1_objs = {
'Constant': init_ops.Constant,
'GlorotNormal': init_ops.GlorotNormal,
'GlorotUniform': init_ops.GlorotUniform,
'Identity': init_ops.Identity,
'Ones': init_ops.Ones,
'Orthogonal': init_ops.Orthogonal,
'VarianceScaling': init_ops.VarianceScaling,
'Zeros': init_ops.Zeros,
'HeNormal': initializers_v1.HeNormal,
'HeUniform': initializers_v1.HeUniform,
'LecunNormal': initializers_v1.LecunNormal,
'LecunUniform': initializers_v1.LecunUniform,
'RandomNormal': initializers_v1.RandomNormal,
'RandomUniform': initializers_v1.RandomUniform,
'TruncatedNormal': initializers_v1.TruncatedNormal,
}
for key, value in v1_objs.items():
LOCAL.ALL_OBJECTS[key] = value
# Functional aliases.
LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value
# More compatibility aliases.
LOCAL.ALL_OBJECTS['normal'] = LOCAL.ALL_OBJECTS['random_normal']
LOCAL.ALL_OBJECTS['uniform'] = LOCAL.ALL_OBJECTS['random_uniform']
LOCAL.ALL_OBJECTS['one'] = LOCAL.ALL_OBJECTS['ones']
LOCAL.ALL_OBJECTS['zero'] = LOCAL.ALL_OBJECTS['zeros']
# For backwards compatibility, we populate this file with the objects
# from ALL_OBJECTS. We make no guarantees as to whether these objects will
# be using their correct version.
populate_deserializable_objects()
globals().update(LOCAL.ALL_OBJECTS)
# Utility functions
@keras_export('keras.initializers.serialize')
def serialize(initializer):
return serialize_keras_object(initializer)
return generic_utils.serialize_keras_object(initializer)
@keras_export('keras.initializers.deserialize')
def deserialize(config, custom_objects=None):
"""Return an `Initializer` object from its config."""
module_objects = globals()
return deserialize_keras_object(
populate_deserializable_objects()
return generic_utils.deserialize_keras_object(
config,
module_objects=module_objects,
module_objects=LOCAL.ALL_OBJECTS,
custom_objects=custom_objects,
printable_module_name='initializer')
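A quick round trip through the new initializer code path (a sketch; assumes TF 2.x, where serialization records the V2 class name):

from tensorflow.keras import initializers

init = initializers.GlorotUniform(seed=42)
config = initializers.serialize(init)
# config == {'class_name': 'GlorotUniform', 'config': {'seed': 42}}
restored = initializers.deserialize(config)
assert isinstance(restored, initializers.GlorotUniform)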

View File

@@ -91,7 +91,7 @@ class KerasInitializersTest(test.TestCase):
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(1. / fan_in)
self._runner(
initializers.lecun_uniformV2(seed=123),
initializers.LecunUniformV2(seed=123),
tensor_shape,
target_mean=0.,
target_std=std)
@@ -113,7 +113,7 @@ class KerasInitializersTest(test.TestCase):
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / fan_in)
self._runner(
initializers.he_uniformV2(seed=123),
initializers.HeUniformV2(seed=123),
tensor_shape,
target_mean=0.,
target_std=std)
@@ -124,7 +124,7 @@ class KerasInitializersTest(test.TestCase):
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(1. / fan_in)
self._runner(
initializers.lecun_normalV2(seed=123),
initializers.LecunNormalV2(seed=123),
tensor_shape,
target_mean=0.,
target_std=std)
@@ -146,7 +146,7 @@ class KerasInitializersTest(test.TestCase):
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / fan_in)
self._runner(
initializers.he_normalV2(seed=123),
initializers.HeNormalV2(seed=123),
tensor_shape,
target_mean=0.,
target_std=std)

View File

@@ -569,7 +569,7 @@ class ConvLSTM2DCell(DropoutRNNCellMixin, Layer):
def bias_initializer(_, *args, **kwargs):
return K.concatenate([
self.bias_initializer((self.filters,), *args, **kwargs),
initializers.Ones()((self.filters,), *args, **kwargs),
initializers.get('ones')((self.filters,), *args, **kwargs),
self.bias_initializer((self.filters * 2,), *args, **kwargs),
])
else:

View File

@@ -2340,7 +2340,7 @@ class LSTMCell(DropoutRNNCellMixin, Layer):
def bias_initializer(_, *args, **kwargs):
return K.concatenate([
self.bias_initializer((self.units,), *args, **kwargs),
initializers.Ones()((self.units,), *args, **kwargs),
initializers.get('ones')((self.units,), *args, **kwargs),
self.bias_initializer((self.units * 2,), *args, **kwargs),
])
else:

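For context, the bias_initializer closures in the two recurrent hunks above implement unit_forget_bias: the bias is laid out as [input, forget, cell, output] blocks, and only the forget-gate block is forced to ones, which helps gradients flow early in training. A NumPy illustration with hypothetical values:

import numpy as np

units = 2
user_init = np.zeros  # stand-in for the user-provided bias_initializer

bias = np.concatenate([
    user_init(units),       # input gate
    np.ones(units),         # forget gate, forced to ones
    user_init(units * 2),   # cell candidate and output gate
])
# bias -> [0. 0. 1. 1. 0. 0. 0. 0.]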
View File

@@ -21,51 +21,125 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from tensorflow.python import tf2
from tensorflow.python.keras.engine.base_layer import AddLoss
from tensorflow.python.keras.engine.base_layer import AddMetric
from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.layers.advanced_activations import *
from tensorflow.python.keras.layers.convolutional import *
from tensorflow.python.keras.layers.convolutional_recurrent import *
from tensorflow.python.keras.layers.core import *
from tensorflow.python.keras.layers.cudnn_recurrent import *
from tensorflow.python.keras.layers.dense_attention import *
from tensorflow.python.keras.layers.embeddings import *
from tensorflow.python.keras.layers.local import *
from tensorflow.python.keras.layers.merge import *
from tensorflow.python.keras.layers.noise import *
from tensorflow.python.keras.layers.normalization import *
from tensorflow.python.keras.layers.pooling import *
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import *
from tensorflow.python.keras.layers.preprocessing.normalization_v1 import *
from tensorflow.python.keras.layers.recurrent import *
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import *
from tensorflow.python.keras.layers.wrappers import *
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.layers import advanced_activations
from tensorflow.python.keras.layers import convolutional
from tensorflow.python.keras.layers import convolutional_recurrent
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import cudnn_recurrent
from tensorflow.python.keras.layers import dense_attention
from tensorflow.python.keras.layers import embeddings
from tensorflow.python.keras.layers import local
from tensorflow.python.keras.layers import merge
from tensorflow.python.keras.layers import noise
from tensorflow.python.keras.layers import normalization
from tensorflow.python.keras.layers import normalization_v2
from tensorflow.python.keras.layers import pooling
from tensorflow.python.keras.layers import recurrent
from tensorflow.python.keras.layers import recurrent_v2
from tensorflow.python.keras.layers import rnn_cell_wrapper_v2
from tensorflow.python.keras.layers import wrappers
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization
from tensorflow.python.keras.layers.preprocessing import normalization_v1 as preprocessing_normalization_v1
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.util import tf_inspect as inspect
from tensorflow.python.util.tf_export import keras_export
if tf2.enabled():
from tensorflow.python.keras.layers.normalization_v2 import * # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.layers.preprocessing.normalization import * # pylint: disable=g-import-not-at-top
# This deserialization table is added for backward compatibility, as in TF 1.13,
# BatchNormalizationV1 and BatchNormalizationV2 are used as class name for v1
# and v2 version of BatchNormalization, respectively. Here we explicitly convert
# them to the canonical name in the config of deserialization.
_DESERIALIZATION_TABLE = {
'BatchNormalizationV1': 'BatchNormalization',
'BatchNormalizationV2': 'BatchNormalization',
}
ALL_MODULES = (
base_layer,
input_layer,
advanced_activations,
convolutional,
convolutional_recurrent,
core,
cudnn_recurrent,
dense_attention,
embeddings,
local,
merge,
noise,
normalization,
pooling,
image_preprocessing,
preprocessing_normalization_v1,
recurrent,
wrappers
)
ALL_V2_MODULES = (
rnn_cell_wrapper_v2,
normalization_v2,
recurrent_v2,
preprocessing_normalization
)
# ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
# thread-local to avoid concurrent mutations.
LOCAL = threading.local()
def populate_deserializable_objects():
"""Populates dict ALL_OBJECTS with every built-in layer.
"""
global LOCAL
if not hasattr(LOCAL, 'ALL_OBJECTS'):
LOCAL.ALL_OBJECTS = {}
LOCAL.GENERATED_WITH_V2 = None
if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
# Objects dict is already generated for the proper TF version:
# do nothing.
return
LOCAL.ALL_OBJECTS = {}
LOCAL.GENERATED_WITH_V2 = tf2.enabled()
base_cls = base_layer.Layer
generic_utils.populate_dict_with_module_objects(
LOCAL.ALL_OBJECTS,
ALL_MODULES,
obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
# Overwrite certain V1 objects with V2 versions
if tf2.enabled():
generic_utils.populate_dict_with_module_objects(
LOCAL.ALL_OBJECTS,
ALL_V2_MODULES,
obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
# These deserialization aliases are added for backward compatibility,
# as in TF 1.13, "BatchNormalizationV1" and "BatchNormalizationV2"
# were used as class name for v1 and v2 version of BatchNormalization,
# respectively. Here we explicitly convert them to their canonical names.
LOCAL.ALL_OBJECTS['BatchNormalizationV1'] = normalization.BatchNormalization
LOCAL.ALL_OBJECTS[
'BatchNormalizationV2'] = normalization_v2.BatchNormalization
# Prevent circular dependencies.
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.premade.linear import LinearModel # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.premade.wide_deep import WideDeepModel # pylint: disable=g-import-not-at-top
from tensorflow.python.feature_column import dense_features # pylint: disable=g-import-not-at-top
from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top
LOCAL.ALL_OBJECTS['Input'] = input_layer.Input
LOCAL.ALL_OBJECTS['Network'] = models.Network
LOCAL.ALL_OBJECTS['Model'] = models.Model
LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential
LOCAL.ALL_OBJECTS['LinearModel'] = LinearModel
LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel
LOCAL.ALL_OBJECTS['DenseFeatures'] = dense_features.DenseFeatures
LOCAL.ALL_OBJECTS['SequenceFeatures'] = sfc.SequenceFeatures
@keras_export('keras.layers.serialize')
def serialize(layer):
return serialize_keras_object(layer)
return generic_utils.serialize_keras_object(layer)
@keras_export('keras.layers.deserialize')
@@ -80,30 +154,9 @@ def deserialize(config, custom_objects=None):
Returns:
Layer instance (may be Model, Sequential, Network, Layer...)
"""
# Prevent circular dependencies.
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.premade.linear import LinearModel # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.premade.wide_deep import WideDeepModel # pylint: disable=g-import-not-at-top
from tensorflow.python.feature_column import dense_features # pylint: disable=g-import-not-at-top
from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top
globs = globals() # All layers.
globs['Network'] = models.Network
globs['Model'] = models.Model
globs['Sequential'] = models.Sequential
globs['LinearModel'] = LinearModel
globs['WideDeepModel'] = WideDeepModel
# Prevent circular dependencies with FeatureColumn serialization.
globs['DenseFeatures'] = dense_features.DenseFeatures
globs['SequenceFeatures'] = sfc.SequenceFeatures
layer_class_name = config['class_name']
if layer_class_name in _DESERIALIZATION_TABLE:
config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name]
return deserialize_keras_object(
populate_deserializable_objects()
return generic_utils.deserialize_keras_object(
config,
module_objects=globs,
module_objects=LOCAL.ALL_OBJECTS,
custom_objects=custom_objects,
printable_module_name='layer')
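A usage sketch of the layer path (assumes TF 2.x after this change; the aliases registered above let configs written by TF 1.13 load):

from tensorflow import keras

# A config saved under the old TF 1.13 class name still deserializes.
layer = keras.layers.deserialize(
    {'class_name': 'BatchNormalizationV2', 'config': {'momentum': 0.9}})
assert type(layer).__name__ == 'BatchNormalization'
assert layer.momentum == 0.9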

View File

@@ -124,12 +124,6 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase):
layer = batchnorm_layer(
momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2')
config = keras.layers.serialize(layer)
# To simulate if BatchNormalizationV1 or BatchNormalizationV2 appears in the
# saved model.
if batchnorm_layer is batchnorm_v1.BatchNormalization:
config['class_name'] = 'BatchNormalizationV1'
else:
config['class_name'] = 'BatchNormalizationV2'
new_layer = keras.layers.deserialize(config)
self.assertEqual(new_layer.momentum, 0.9)
if tf2.enabled():

View File

@@ -364,7 +364,8 @@ def deserialize_keras_object(identifier,
else:
obj = module_objects.get(object_name)
if obj is None:
raise ValueError('Unknown ' + printable_module_name + ':' + object_name)
raise ValueError(
'Unknown ' + printable_module_name + ': ' + object_name)
# Classes passed by name are instantiated with no args, functions are
# returned as-is.
if tf_inspect.isclass(obj):
@@ -783,6 +784,13 @@ def is_default(method):
return getattr(method, '_is_default', False)
def populate_dict_with_module_objects(target_dict, modules, obj_filter):
for module in modules:
for name in dir(module):
obj = getattr(module, name)
if obj_filter(obj):
target_dict[name] = obj
# Aliases
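Using the helper just defined, a minimal self-contained illustration (hypothetical use against the standard math module):

import math

objs = {}
populate_dict_with_module_objects(objs, [math], obj_filter=callable)
# objs now maps names like 'sqrt' and 'cos' to the corresponding functions
# (other callable attributes, such as '__loader__', are picked up as well).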