Make Keras deserialization aware of enable/disable_v2_behavior switches.
PiperOrigin-RevId: 303355879 Change-Id: I3b8f7ae7db4ed1185585c6d56ed445cd4c61b6ca
This commit is contained in:
parent
cd59b293b5
commit
6766a8ee3e
@ -538,11 +538,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
if initializer is None:
|
||||
# If dtype is DT_FLOAT, provide a uniform unit scaling initializer
|
||||
if dtype.is_floating:
|
||||
initializer = initializers.glorot_uniform()
|
||||
initializer = initializers.get('glorot_uniform')
|
||||
# If dtype is DT_INT/DT_UINT, provide a default value `zero`
|
||||
# If dtype is DT_BOOL, provide a default value `FALSE`
|
||||
elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
|
||||
initializer = initializers.zeros()
|
||||
initializer = initializers.get('zeros')
|
||||
# NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here?
|
||||
else:
|
||||
raise ValueError('An initializer for variable %s of type %s is required'
|
||||
|
@ -402,7 +402,7 @@ class Layer(base_layer.Layer):
|
||||
if initializer is None:
|
||||
# If dtype is DT_FLOAT, provide a uniform unit scaling initializer
|
||||
if dtype.is_floating:
|
||||
initializer = initializers.glorot_uniform()
|
||||
initializer = initializers.get('glorot_uniform')
|
||||
# If dtype is DT_INT/DT_UINT, provide a default value `zero`
|
||||
# If dtype is DT_BOOL, provide a default value `FALSE`
|
||||
elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
|
||||
|
@ -14,114 +14,134 @@
|
||||
# ==============================================================================
|
||||
"""Keras initializer serialization / deserialization.
|
||||
"""
|
||||
# pylint: disable=unused-import
|
||||
# pylint: disable=line-too-long
|
||||
# pylint: disable=g-import-not-at-top
|
||||
# pylint: disable=g-bad-import-order
|
||||
# pylint: disable=invalid-name
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
import six
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||
from tensorflow.python.keras.initializers import initializers_v1
|
||||
from tensorflow.python.keras.initializers import initializers_v2
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.util import tf_inspect as inspect
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
# These imports are brought in so that keras.initializers.deserialize
|
||||
# has them available in module_objects.
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import Constant as ConstantV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import GlorotNormal as GlorotNormalV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import GlorotUniform as GlorotUniformV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import HeNormal as HeNormalV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import HeUniform as HeUniformV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import Identity as IdentityV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import Initializer
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import LecunNormal as LecunNormalV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import LecunUniform as LecunUniformV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import Ones as OnesV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import Orthogonal as OrthogonalV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import RandomNormal as RandomNormalV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import RandomUniform as RandomUniformV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import TruncatedNormal as TruncatedNormalV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import VarianceScaling as VarianceScalingV2
|
||||
from tensorflow.python.keras.initializers.initializers_v2 import Zeros as ZerosV2
|
||||
|
||||
if tf2.enabled():
|
||||
Constant = ConstantV2
|
||||
GlorotNormal = GlorotNormalV2
|
||||
GlorotUniform = GlorotUniformV2
|
||||
HeNormal = HeNormalV2
|
||||
HeUniform = HeUniformV2
|
||||
Identity = IdentityV2
|
||||
LecunNormal = LecunNormalV2
|
||||
LecunUniform = LecunUniformV2
|
||||
Ones = OnesV2
|
||||
Orthogonal = OrthogonalV2
|
||||
RandomNormal = RandomNormalV2
|
||||
RandomUniform = RandomUniformV2
|
||||
TruncatedNormal = TruncatedNormalV2
|
||||
VarianceScaling = VarianceScalingV2
|
||||
Zeros = ZerosV2
|
||||
else:
|
||||
from tensorflow.python.ops.init_ops import Constant
|
||||
from tensorflow.python.ops.init_ops import GlorotNormal
|
||||
from tensorflow.python.ops.init_ops import GlorotUniform
|
||||
from tensorflow.python.ops.init_ops import Identity
|
||||
from tensorflow.python.ops.init_ops import Ones
|
||||
from tensorflow.python.ops.init_ops import Orthogonal
|
||||
from tensorflow.python.ops.init_ops import VarianceScaling
|
||||
from tensorflow.python.ops.init_ops import Zeros
|
||||
from tensorflow.python.keras.initializers.initializers_v1 import HeNormal
|
||||
from tensorflow.python.keras.initializers.initializers_v1 import HeUniform
|
||||
from tensorflow.python.keras.initializers.initializers_v1 import LecunNormal
|
||||
from tensorflow.python.keras.initializers.initializers_v1 import LecunUniform
|
||||
from tensorflow.python.keras.initializers.initializers_v1 import RandomNormal
|
||||
from tensorflow.python.keras.initializers.initializers_v1 import RandomUniform
|
||||
from tensorflow.python.keras.initializers.initializers_v1 import TruncatedNormal
|
||||
# LOCAL.ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
|
||||
# thread-local to avoid concurrent mutations.
|
||||
LOCAL = threading.local()
|
||||
|
||||
|
||||
# Compatibility aliases
|
||||
glorot_normal = GlorotNormal
|
||||
glorot_uniform = GlorotUniform
|
||||
he_normal = HeNormal
|
||||
he_uniform = HeUniform
|
||||
lecun_normal = LecunNormal
|
||||
lecun_uniform = LecunUniform
|
||||
zero = zeros = Zeros
|
||||
one = ones = Ones
|
||||
constant = Constant
|
||||
uniform = random_uniform = RandomUniform
|
||||
normal = random_normal = RandomNormal
|
||||
truncated_normal = TruncatedNormal
|
||||
identity = Identity
|
||||
orthogonal = Orthogonal
|
||||
def populate_deserializable_objects():
|
||||
"""Populates dict ALL_OBJECTS with every built-in initializer.
|
||||
"""
|
||||
global LOCAL
|
||||
if not hasattr(LOCAL, 'ALL_OBJECTS'):
|
||||
LOCAL.ALL_OBJECTS = {}
|
||||
LOCAL.GENERATED_WITH_V2 = None
|
||||
|
||||
# For unit tests
|
||||
glorot_normalV2 = GlorotNormalV2
|
||||
glorot_uniformV2 = GlorotUniformV2
|
||||
he_normalV2 = HeNormalV2
|
||||
he_uniformV2 = HeUniformV2
|
||||
lecun_normalV2 = LecunNormalV2
|
||||
lecun_uniformV2 = LecunUniformV2
|
||||
if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
|
||||
# Objects dict is already generated for the proper TF version:
|
||||
# do nothing.
|
||||
return
|
||||
|
||||
LOCAL.ALL_OBJECTS = {}
|
||||
LOCAL.GENERATED_WITH_V2 = tf2.enabled()
|
||||
|
||||
# Compatibility aliases (need to exist in both V1 and V2).
|
||||
LOCAL.ALL_OBJECTS['ConstantV2'] = initializers_v2.Constant
|
||||
LOCAL.ALL_OBJECTS['GlorotNormalV2'] = initializers_v2.GlorotNormal
|
||||
LOCAL.ALL_OBJECTS['GlorotUniformV2'] = initializers_v2.GlorotUniform
|
||||
LOCAL.ALL_OBJECTS['HeNormalV2'] = initializers_v2.HeNormal
|
||||
LOCAL.ALL_OBJECTS['HeUniformV2'] = initializers_v2.HeUniform
|
||||
LOCAL.ALL_OBJECTS['IdentityV2'] = initializers_v2.Identity
|
||||
LOCAL.ALL_OBJECTS['LecunNormalV2'] = initializers_v2.LecunNormal
|
||||
LOCAL.ALL_OBJECTS['LecunUniformV2'] = initializers_v2.LecunUniform
|
||||
LOCAL.ALL_OBJECTS['OnesV2'] = initializers_v2.Ones
|
||||
LOCAL.ALL_OBJECTS['OrthogonalV2'] = initializers_v2.Orthogonal
|
||||
LOCAL.ALL_OBJECTS['RandomNormalV2'] = initializers_v2.RandomNormal
|
||||
LOCAL.ALL_OBJECTS['RandomUniformV2'] = initializers_v2.RandomUniform
|
||||
LOCAL.ALL_OBJECTS['TruncatedNormalV2'] = initializers_v2.TruncatedNormal
|
||||
LOCAL.ALL_OBJECTS['VarianceScalingV2'] = initializers_v2.VarianceScaling
|
||||
LOCAL.ALL_OBJECTS['ZerosV2'] = initializers_v2.Zeros
|
||||
|
||||
# Out of an abundance of caution we also include these aliases that have
|
||||
# a non-zero probability of having been included in saved configs in the past.
|
||||
LOCAL.ALL_OBJECTS['glorot_normalV2'] = initializers_v2.GlorotNormal
|
||||
LOCAL.ALL_OBJECTS['glorot_uniformV2'] = initializers_v2.GlorotUniform
|
||||
LOCAL.ALL_OBJECTS['he_normalV2'] = initializers_v2.HeNormal
|
||||
LOCAL.ALL_OBJECTS['he_uniformV2'] = initializers_v2.HeUniform
|
||||
LOCAL.ALL_OBJECTS['lecun_normalV2'] = initializers_v2.LecunNormal
|
||||
LOCAL.ALL_OBJECTS['lecun_uniformV2'] = initializers_v2.LecunUniform
|
||||
|
||||
if tf2.enabled():
|
||||
# For V2, entries are generated automatically based on the content of
|
||||
# initializers_v2.py.
|
||||
v2_objs = {}
|
||||
base_cls = initializers_v2.Initializer
|
||||
generic_utils.populate_dict_with_module_objects(
|
||||
v2_objs,
|
||||
[initializers_v2],
|
||||
obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
|
||||
for key, value in v2_objs.items():
|
||||
LOCAL.ALL_OBJECTS[key] = value
|
||||
# Functional aliases.
|
||||
LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value
|
||||
else:
|
||||
# V1 initializers.
|
||||
v1_objs = {
|
||||
'Constant': init_ops.Constant,
|
||||
'GlorotNormal': init_ops.GlorotNormal,
|
||||
'GlorotUniform': init_ops.GlorotUniform,
|
||||
'Identity': init_ops.Identity,
|
||||
'Ones': init_ops.Ones,
|
||||
'Orthogonal': init_ops.Orthogonal,
|
||||
'VarianceScaling': init_ops.VarianceScaling,
|
||||
'Zeros': init_ops.Zeros,
|
||||
'HeNormal': initializers_v1.HeNormal,
|
||||
'HeUniform': initializers_v1.HeUniform,
|
||||
'LecunNormal': initializers_v1.LecunNormal,
|
||||
'LecunUniform': initializers_v1.LecunUniform,
|
||||
'RandomNormal': initializers_v1.RandomNormal,
|
||||
'RandomUniform': initializers_v1.RandomUniform,
|
||||
'TruncatedNormal': initializers_v1.TruncatedNormal,
|
||||
}
|
||||
for key, value in v1_objs.items():
|
||||
LOCAL.ALL_OBJECTS[key] = value
|
||||
# Functional aliases.
|
||||
LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value
|
||||
|
||||
# More compatibility aliases.
|
||||
LOCAL.ALL_OBJECTS['normal'] = LOCAL.ALL_OBJECTS['random_normal']
|
||||
LOCAL.ALL_OBJECTS['uniform'] = LOCAL.ALL_OBJECTS['random_uniform']
|
||||
LOCAL.ALL_OBJECTS['one'] = LOCAL.ALL_OBJECTS['ones']
|
||||
LOCAL.ALL_OBJECTS['zero'] = LOCAL.ALL_OBJECTS['zeros']
|
||||
|
||||
|
||||
# For backwards compatibility, we populate this file with the objects
|
||||
# from ALL_OBJECTS. We make no guarantees as to whether these objects will
|
||||
# using their correct version.
|
||||
populate_deserializable_objects()
|
||||
globals().update(LOCAL.ALL_OBJECTS)
|
||||
|
||||
# Utility functions
|
||||
|
||||
|
||||
@keras_export('keras.initializers.serialize')
|
||||
def serialize(initializer):
|
||||
return serialize_keras_object(initializer)
|
||||
return generic_utils.serialize_keras_object(initializer)
|
||||
|
||||
|
||||
@keras_export('keras.initializers.deserialize')
|
||||
def deserialize(config, custom_objects=None):
|
||||
"""Return an `Initializer` object from its config."""
|
||||
module_objects = globals()
|
||||
return deserialize_keras_object(
|
||||
populate_deserializable_objects()
|
||||
return generic_utils.deserialize_keras_object(
|
||||
config,
|
||||
module_objects=module_objects,
|
||||
module_objects=LOCAL.ALL_OBJECTS,
|
||||
custom_objects=custom_objects,
|
||||
printable_module_name='initializer')
|
||||
|
||||
|
@ -91,7 +91,7 @@ class KerasInitializersTest(test.TestCase):
|
||||
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
||||
std = np.sqrt(1. / fan_in)
|
||||
self._runner(
|
||||
initializers.lecun_uniformV2(seed=123),
|
||||
initializers.LecunUniformV2(seed=123),
|
||||
tensor_shape,
|
||||
target_mean=0.,
|
||||
target_std=std)
|
||||
@ -113,7 +113,7 @@ class KerasInitializersTest(test.TestCase):
|
||||
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
||||
std = np.sqrt(2. / fan_in)
|
||||
self._runner(
|
||||
initializers.he_uniformV2(seed=123),
|
||||
initializers.HeUniformV2(seed=123),
|
||||
tensor_shape,
|
||||
target_mean=0.,
|
||||
target_std=std)
|
||||
@ -124,7 +124,7 @@ class KerasInitializersTest(test.TestCase):
|
||||
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
||||
std = np.sqrt(1. / fan_in)
|
||||
self._runner(
|
||||
initializers.lecun_normalV2(seed=123),
|
||||
initializers.LecunNormalV2(seed=123),
|
||||
tensor_shape,
|
||||
target_mean=0.,
|
||||
target_std=std)
|
||||
@ -146,7 +146,7 @@ class KerasInitializersTest(test.TestCase):
|
||||
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
||||
std = np.sqrt(2. / fan_in)
|
||||
self._runner(
|
||||
initializers.he_normalV2(seed=123),
|
||||
initializers.HeNormalV2(seed=123),
|
||||
tensor_shape,
|
||||
target_mean=0.,
|
||||
target_std=std)
|
||||
|
@ -569,7 +569,7 @@ class ConvLSTM2DCell(DropoutRNNCellMixin, Layer):
|
||||
def bias_initializer(_, *args, **kwargs):
|
||||
return K.concatenate([
|
||||
self.bias_initializer((self.filters,), *args, **kwargs),
|
||||
initializers.Ones()((self.filters,), *args, **kwargs),
|
||||
initializers.get('ones')((self.filters,), *args, **kwargs),
|
||||
self.bias_initializer((self.filters * 2,), *args, **kwargs),
|
||||
])
|
||||
else:
|
||||
|
@ -2340,7 +2340,7 @@ class LSTMCell(DropoutRNNCellMixin, Layer):
|
||||
def bias_initializer(_, *args, **kwargs):
|
||||
return K.concatenate([
|
||||
self.bias_initializer((self.units,), *args, **kwargs),
|
||||
initializers.Ones()((self.units,), *args, **kwargs),
|
||||
initializers.get('ones')((self.units,), *args, **kwargs),
|
||||
self.bias_initializer((self.units * 2,), *args, **kwargs),
|
||||
])
|
||||
else:
|
||||
|
@ -21,51 +21,125 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.keras.engine.base_layer import AddLoss
|
||||
from tensorflow.python.keras.engine.base_layer import AddMetric
|
||||
from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
|
||||
from tensorflow.python.keras.engine.input_layer import Input
|
||||
from tensorflow.python.keras.engine.input_layer import InputLayer
|
||||
from tensorflow.python.keras.layers.advanced_activations import *
|
||||
from tensorflow.python.keras.layers.convolutional import *
|
||||
from tensorflow.python.keras.layers.convolutional_recurrent import *
|
||||
from tensorflow.python.keras.layers.core import *
|
||||
from tensorflow.python.keras.layers.cudnn_recurrent import *
|
||||
from tensorflow.python.keras.layers.dense_attention import *
|
||||
from tensorflow.python.keras.layers.embeddings import *
|
||||
from tensorflow.python.keras.layers.local import *
|
||||
from tensorflow.python.keras.layers.merge import *
|
||||
from tensorflow.python.keras.layers.noise import *
|
||||
from tensorflow.python.keras.layers.normalization import *
|
||||
from tensorflow.python.keras.layers.pooling import *
|
||||
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import *
|
||||
from tensorflow.python.keras.layers.preprocessing.normalization_v1 import *
|
||||
from tensorflow.python.keras.layers.recurrent import *
|
||||
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import *
|
||||
from tensorflow.python.keras.layers.wrappers import *
|
||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.engine import input_layer
|
||||
from tensorflow.python.keras.layers import advanced_activations
|
||||
from tensorflow.python.keras.layers import convolutional
|
||||
from tensorflow.python.keras.layers import convolutional_recurrent
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.layers import cudnn_recurrent
|
||||
from tensorflow.python.keras.layers import dense_attention
|
||||
from tensorflow.python.keras.layers import embeddings
|
||||
from tensorflow.python.keras.layers import local
|
||||
from tensorflow.python.keras.layers import merge
|
||||
from tensorflow.python.keras.layers import noise
|
||||
from tensorflow.python.keras.layers import normalization
|
||||
from tensorflow.python.keras.layers import normalization_v2
|
||||
from tensorflow.python.keras.layers import pooling
|
||||
from tensorflow.python.keras.layers import recurrent
|
||||
from tensorflow.python.keras.layers import recurrent_v2
|
||||
from tensorflow.python.keras.layers import rnn_cell_wrapper_v2
|
||||
from tensorflow.python.keras.layers import wrappers
|
||||
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
|
||||
from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization
|
||||
from tensorflow.python.keras.layers.preprocessing import normalization_v1 as preprocessing_normalization_v1
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.util import tf_inspect as inspect
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
if tf2.enabled():
|
||||
from tensorflow.python.keras.layers.normalization_v2 import * # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.keras.layers.preprocessing.normalization import * # pylint: disable=g-import-not-at-top
|
||||
|
||||
# This deserialization table is added for backward compatibility, as in TF 1.13,
|
||||
# BatchNormalizationV1 and BatchNormalizationV2 are used as class name for v1
|
||||
# and v2 version of BatchNormalization, respectively. Here we explicitly convert
|
||||
# them to the canonical name in the config of deserialization.
|
||||
_DESERIALIZATION_TABLE = {
|
||||
'BatchNormalizationV1': 'BatchNormalization',
|
||||
'BatchNormalizationV2': 'BatchNormalization',
|
||||
}
|
||||
ALL_MODULES = (
|
||||
base_layer,
|
||||
input_layer,
|
||||
advanced_activations,
|
||||
convolutional,
|
||||
convolutional_recurrent,
|
||||
core,
|
||||
cudnn_recurrent,
|
||||
dense_attention,
|
||||
embeddings,
|
||||
local,
|
||||
merge,
|
||||
noise,
|
||||
normalization,
|
||||
pooling,
|
||||
image_preprocessing,
|
||||
preprocessing_normalization_v1,
|
||||
recurrent,
|
||||
wrappers
|
||||
)
|
||||
ALL_V2_MODULES = (
|
||||
rnn_cell_wrapper_v2,
|
||||
normalization_v2,
|
||||
recurrent_v2,
|
||||
preprocessing_normalization
|
||||
)
|
||||
|
||||
# ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
|
||||
# thread-local to avoid concurrent mutations.
|
||||
LOCAL = threading.local()
|
||||
|
||||
|
||||
def populate_deserializable_objects():
|
||||
"""Populates dict ALL_OBJECTS with every built-in layer.
|
||||
"""
|
||||
global LOCAL
|
||||
if not hasattr(LOCAL, 'ALL_OBJECTS'):
|
||||
LOCAL.ALL_OBJECTS = {}
|
||||
LOCAL.GENERATED_WITH_V2 = None
|
||||
|
||||
if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
|
||||
# Objects dict is already generated for the proper TF version:
|
||||
# do nothing.
|
||||
return
|
||||
|
||||
LOCAL.ALL_OBJECTS = {}
|
||||
LOCAL.GENERATED_WITH_V2 = tf2.enabled()
|
||||
|
||||
base_cls = base_layer.Layer
|
||||
generic_utils.populate_dict_with_module_objects(
|
||||
LOCAL.ALL_OBJECTS,
|
||||
ALL_MODULES,
|
||||
obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
|
||||
|
||||
# Overwrite certain V1 objects with V2 versions
|
||||
if tf2.enabled():
|
||||
generic_utils.populate_dict_with_module_objects(
|
||||
LOCAL.ALL_OBJECTS,
|
||||
ALL_V2_MODULES,
|
||||
obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
|
||||
|
||||
# These deserialization aliases are added for backward compatibility,
|
||||
# as in TF 1.13, "BatchNormalizationV1" and "BatchNormalizationV2"
|
||||
# were used as class name for v1 and v2 version of BatchNormalization,
|
||||
# respectively. Here we explicitly convert them to their canonical names.
|
||||
LOCAL.ALL_OBJECTS['BatchNormalizationV1'] = normalization.BatchNormalization
|
||||
LOCAL.ALL_OBJECTS[
|
||||
'BatchNormalizationV2'] = normalization_v2.BatchNormalization
|
||||
|
||||
# Prevent circular dependencies.
|
||||
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.keras.premade.linear import LinearModel # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.keras.premade.wide_deep import WideDeepModel # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.feature_column import dense_features # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top
|
||||
|
||||
LOCAL.ALL_OBJECTS['Input'] = input_layer.Input
|
||||
LOCAL.ALL_OBJECTS['Network'] = models.Network
|
||||
LOCAL.ALL_OBJECTS['Model'] = models.Model
|
||||
LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential
|
||||
LOCAL.ALL_OBJECTS['LinearModel'] = LinearModel
|
||||
LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel
|
||||
LOCAL.ALL_OBJECTS['DenseFeatures'] = dense_features.DenseFeatures
|
||||
LOCAL.ALL_OBJECTS['SequenceFeatures'] = sfc.SequenceFeatures
|
||||
|
||||
|
||||
@keras_export('keras.layers.serialize')
|
||||
def serialize(layer):
|
||||
return serialize_keras_object(layer)
|
||||
return generic_utils.serialize_keras_object(layer)
|
||||
|
||||
|
||||
@keras_export('keras.layers.deserialize')
|
||||
@ -80,30 +154,9 @@ def deserialize(config, custom_objects=None):
|
||||
Returns:
|
||||
Layer instance (may be Model, Sequential, Network, Layer...)
|
||||
"""
|
||||
# Prevent circular dependencies.
|
||||
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.keras.premade.linear import LinearModel # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.keras.premade.wide_deep import WideDeepModel # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.feature_column import dense_features # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top
|
||||
|
||||
globs = globals() # All layers.
|
||||
globs['Network'] = models.Network
|
||||
globs['Model'] = models.Model
|
||||
globs['Sequential'] = models.Sequential
|
||||
globs['LinearModel'] = LinearModel
|
||||
globs['WideDeepModel'] = WideDeepModel
|
||||
|
||||
# Prevent circular dependencies with FeatureColumn serialization.
|
||||
globs['DenseFeatures'] = dense_features.DenseFeatures
|
||||
globs['SequenceFeatures'] = sfc.SequenceFeatures
|
||||
|
||||
layer_class_name = config['class_name']
|
||||
if layer_class_name in _DESERIALIZATION_TABLE:
|
||||
config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name]
|
||||
|
||||
return deserialize_keras_object(
|
||||
populate_deserializable_objects()
|
||||
return generic_utils.deserialize_keras_object(
|
||||
config,
|
||||
module_objects=globs,
|
||||
module_objects=LOCAL.ALL_OBJECTS,
|
||||
custom_objects=custom_objects,
|
||||
printable_module_name='layer')
|
||||
|
@ -124,12 +124,6 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase):
|
||||
layer = batchnorm_layer(
|
||||
momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2')
|
||||
config = keras.layers.serialize(layer)
|
||||
# To simulate if BatchNormalizationV1 or BatchNormalizationV2 appears in the
|
||||
# saved model.
|
||||
if batchnorm_layer is batchnorm_v1.BatchNormalization:
|
||||
config['class_name'] = 'BatchNormalizationV1'
|
||||
else:
|
||||
config['class_name'] = 'BatchNormalizationV2'
|
||||
new_layer = keras.layers.deserialize(config)
|
||||
self.assertEqual(new_layer.momentum, 0.9)
|
||||
if tf2.enabled():
|
||||
|
@ -364,7 +364,8 @@ def deserialize_keras_object(identifier,
|
||||
else:
|
||||
obj = module_objects.get(object_name)
|
||||
if obj is None:
|
||||
raise ValueError('Unknown ' + printable_module_name + ':' + object_name)
|
||||
raise ValueError(
|
||||
'Unknown ' + printable_module_name + ': ' + object_name)
|
||||
# Classes passed by name are instantiated with no args, functions are
|
||||
# returned as-is.
|
||||
if tf_inspect.isclass(obj):
|
||||
@ -783,6 +784,13 @@ def is_default(method):
|
||||
return getattr(method, '_is_default', False)
|
||||
|
||||
|
||||
def populate_dict_with_module_objects(target_dict, modules, obj_filter):
|
||||
for module in modules:
|
||||
for name in dir(module):
|
||||
obj = getattr(module, name)
|
||||
if obj_filter(obj):
|
||||
target_dict[name] = obj
|
||||
|
||||
# Aliases
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user