Make Keras deserialization aware of enable/disable_v2_behavior switches.
PiperOrigin-RevId: 303355879 Change-Id: I3b8f7ae7db4ed1185585c6d56ed445cd4c61b6ca
This commit is contained in:
parent
cd59b293b5
commit
6766a8ee3e
@ -538,11 +538,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||||||
if initializer is None:
|
if initializer is None:
|
||||||
# If dtype is DT_FLOAT, provide a uniform unit scaling initializer
|
# If dtype is DT_FLOAT, provide a uniform unit scaling initializer
|
||||||
if dtype.is_floating:
|
if dtype.is_floating:
|
||||||
initializer = initializers.glorot_uniform()
|
initializer = initializers.get('glorot_uniform')
|
||||||
# If dtype is DT_INT/DT_UINT, provide a default value `zero`
|
# If dtype is DT_INT/DT_UINT, provide a default value `zero`
|
||||||
# If dtype is DT_BOOL, provide a default value `FALSE`
|
# If dtype is DT_BOOL, provide a default value `FALSE`
|
||||||
elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
|
elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
|
||||||
initializer = initializers.zeros()
|
initializer = initializers.get('zeros')
|
||||||
# NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here?
|
# NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here?
|
||||||
else:
|
else:
|
||||||
raise ValueError('An initializer for variable %s of type %s is required'
|
raise ValueError('An initializer for variable %s of type %s is required'
|
||||||
|
@ -402,7 +402,7 @@ class Layer(base_layer.Layer):
|
|||||||
if initializer is None:
|
if initializer is None:
|
||||||
# If dtype is DT_FLOAT, provide a uniform unit scaling initializer
|
# If dtype is DT_FLOAT, provide a uniform unit scaling initializer
|
||||||
if dtype.is_floating:
|
if dtype.is_floating:
|
||||||
initializer = initializers.glorot_uniform()
|
initializer = initializers.get('glorot_uniform')
|
||||||
# If dtype is DT_INT/DT_UINT, provide a default value `zero`
|
# If dtype is DT_INT/DT_UINT, provide a default value `zero`
|
||||||
# If dtype is DT_BOOL, provide a default value `FALSE`
|
# If dtype is DT_BOOL, provide a default value `FALSE`
|
||||||
elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
|
elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
|
||||||
|
@ -14,114 +14,134 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Keras initializer serialization / deserialization.
|
"""Keras initializer serialization / deserialization.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=unused-import
|
|
||||||
# pylint: disable=line-too-long
|
|
||||||
# pylint: disable=g-import-not-at-top
|
|
||||||
# pylint: disable=g-bad-import-order
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import threading
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
from tensorflow.python.keras.initializers import initializers_v1
|
||||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
from tensorflow.python.keras.initializers import initializers_v2
|
||||||
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
|
from tensorflow.python.ops import init_ops
|
||||||
|
from tensorflow.python.util import tf_inspect as inspect
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
# These imports are brought in so that keras.initializers.deserialize
|
|
||||||
# has them available in module_objects.
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import Constant as ConstantV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import GlorotNormal as GlorotNormalV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import GlorotUniform as GlorotUniformV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import HeNormal as HeNormalV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import HeUniform as HeUniformV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import Identity as IdentityV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import Initializer
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import LecunNormal as LecunNormalV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import LecunUniform as LecunUniformV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import Ones as OnesV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import Orthogonal as OrthogonalV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import RandomNormal as RandomNormalV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import RandomUniform as RandomUniformV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import TruncatedNormal as TruncatedNormalV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import VarianceScaling as VarianceScalingV2
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v2 import Zeros as ZerosV2
|
|
||||||
|
|
||||||
if tf2.enabled():
|
# LOCAL.ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
|
||||||
Constant = ConstantV2
|
# thread-local to avoid concurrent mutations.
|
||||||
GlorotNormal = GlorotNormalV2
|
LOCAL = threading.local()
|
||||||
GlorotUniform = GlorotUniformV2
|
|
||||||
HeNormal = HeNormalV2
|
|
||||||
HeUniform = HeUniformV2
|
|
||||||
Identity = IdentityV2
|
|
||||||
LecunNormal = LecunNormalV2
|
|
||||||
LecunUniform = LecunUniformV2
|
|
||||||
Ones = OnesV2
|
|
||||||
Orthogonal = OrthogonalV2
|
|
||||||
RandomNormal = RandomNormalV2
|
|
||||||
RandomUniform = RandomUniformV2
|
|
||||||
TruncatedNormal = TruncatedNormalV2
|
|
||||||
VarianceScaling = VarianceScalingV2
|
|
||||||
Zeros = ZerosV2
|
|
||||||
else:
|
|
||||||
from tensorflow.python.ops.init_ops import Constant
|
|
||||||
from tensorflow.python.ops.init_ops import GlorotNormal
|
|
||||||
from tensorflow.python.ops.init_ops import GlorotUniform
|
|
||||||
from tensorflow.python.ops.init_ops import Identity
|
|
||||||
from tensorflow.python.ops.init_ops import Ones
|
|
||||||
from tensorflow.python.ops.init_ops import Orthogonal
|
|
||||||
from tensorflow.python.ops.init_ops import VarianceScaling
|
|
||||||
from tensorflow.python.ops.init_ops import Zeros
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v1 import HeNormal
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v1 import HeUniform
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v1 import LecunNormal
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v1 import LecunUniform
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v1 import RandomNormal
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v1 import RandomUniform
|
|
||||||
from tensorflow.python.keras.initializers.initializers_v1 import TruncatedNormal
|
|
||||||
|
|
||||||
|
|
||||||
# Compatibility aliases
|
def populate_deserializable_objects():
|
||||||
glorot_normal = GlorotNormal
|
"""Populates dict ALL_OBJECTS with every built-in initializer.
|
||||||
glorot_uniform = GlorotUniform
|
"""
|
||||||
he_normal = HeNormal
|
global LOCAL
|
||||||
he_uniform = HeUniform
|
if not hasattr(LOCAL, 'ALL_OBJECTS'):
|
||||||
lecun_normal = LecunNormal
|
LOCAL.ALL_OBJECTS = {}
|
||||||
lecun_uniform = LecunUniform
|
LOCAL.GENERATED_WITH_V2 = None
|
||||||
zero = zeros = Zeros
|
|
||||||
one = ones = Ones
|
|
||||||
constant = Constant
|
|
||||||
uniform = random_uniform = RandomUniform
|
|
||||||
normal = random_normal = RandomNormal
|
|
||||||
truncated_normal = TruncatedNormal
|
|
||||||
identity = Identity
|
|
||||||
orthogonal = Orthogonal
|
|
||||||
|
|
||||||
# For unit tests
|
if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
|
||||||
glorot_normalV2 = GlorotNormalV2
|
# Objects dict is already generated for the proper TF version:
|
||||||
glorot_uniformV2 = GlorotUniformV2
|
# do nothing.
|
||||||
he_normalV2 = HeNormalV2
|
return
|
||||||
he_uniformV2 = HeUniformV2
|
|
||||||
lecun_normalV2 = LecunNormalV2
|
LOCAL.ALL_OBJECTS = {}
|
||||||
lecun_uniformV2 = LecunUniformV2
|
LOCAL.GENERATED_WITH_V2 = tf2.enabled()
|
||||||
|
|
||||||
|
# Compatibility aliases (need to exist in both V1 and V2).
|
||||||
|
LOCAL.ALL_OBJECTS['ConstantV2'] = initializers_v2.Constant
|
||||||
|
LOCAL.ALL_OBJECTS['GlorotNormalV2'] = initializers_v2.GlorotNormal
|
||||||
|
LOCAL.ALL_OBJECTS['GlorotUniformV2'] = initializers_v2.GlorotUniform
|
||||||
|
LOCAL.ALL_OBJECTS['HeNormalV2'] = initializers_v2.HeNormal
|
||||||
|
LOCAL.ALL_OBJECTS['HeUniformV2'] = initializers_v2.HeUniform
|
||||||
|
LOCAL.ALL_OBJECTS['IdentityV2'] = initializers_v2.Identity
|
||||||
|
LOCAL.ALL_OBJECTS['LecunNormalV2'] = initializers_v2.LecunNormal
|
||||||
|
LOCAL.ALL_OBJECTS['LecunUniformV2'] = initializers_v2.LecunUniform
|
||||||
|
LOCAL.ALL_OBJECTS['OnesV2'] = initializers_v2.Ones
|
||||||
|
LOCAL.ALL_OBJECTS['OrthogonalV2'] = initializers_v2.Orthogonal
|
||||||
|
LOCAL.ALL_OBJECTS['RandomNormalV2'] = initializers_v2.RandomNormal
|
||||||
|
LOCAL.ALL_OBJECTS['RandomUniformV2'] = initializers_v2.RandomUniform
|
||||||
|
LOCAL.ALL_OBJECTS['TruncatedNormalV2'] = initializers_v2.TruncatedNormal
|
||||||
|
LOCAL.ALL_OBJECTS['VarianceScalingV2'] = initializers_v2.VarianceScaling
|
||||||
|
LOCAL.ALL_OBJECTS['ZerosV2'] = initializers_v2.Zeros
|
||||||
|
|
||||||
|
# Out of an abundance of caution we also include these aliases that have
|
||||||
|
# a non-zero probability of having been included in saved configs in the past.
|
||||||
|
LOCAL.ALL_OBJECTS['glorot_normalV2'] = initializers_v2.GlorotNormal
|
||||||
|
LOCAL.ALL_OBJECTS['glorot_uniformV2'] = initializers_v2.GlorotUniform
|
||||||
|
LOCAL.ALL_OBJECTS['he_normalV2'] = initializers_v2.HeNormal
|
||||||
|
LOCAL.ALL_OBJECTS['he_uniformV2'] = initializers_v2.HeUniform
|
||||||
|
LOCAL.ALL_OBJECTS['lecun_normalV2'] = initializers_v2.LecunNormal
|
||||||
|
LOCAL.ALL_OBJECTS['lecun_uniformV2'] = initializers_v2.LecunUniform
|
||||||
|
|
||||||
|
if tf2.enabled():
|
||||||
|
# For V2, entries are generated automatically based on the content of
|
||||||
|
# initializers_v2.py.
|
||||||
|
v2_objs = {}
|
||||||
|
base_cls = initializers_v2.Initializer
|
||||||
|
generic_utils.populate_dict_with_module_objects(
|
||||||
|
v2_objs,
|
||||||
|
[initializers_v2],
|
||||||
|
obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
|
||||||
|
for key, value in v2_objs.items():
|
||||||
|
LOCAL.ALL_OBJECTS[key] = value
|
||||||
|
# Functional aliases.
|
||||||
|
LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value
|
||||||
|
else:
|
||||||
|
# V1 initializers.
|
||||||
|
v1_objs = {
|
||||||
|
'Constant': init_ops.Constant,
|
||||||
|
'GlorotNormal': init_ops.GlorotNormal,
|
||||||
|
'GlorotUniform': init_ops.GlorotUniform,
|
||||||
|
'Identity': init_ops.Identity,
|
||||||
|
'Ones': init_ops.Ones,
|
||||||
|
'Orthogonal': init_ops.Orthogonal,
|
||||||
|
'VarianceScaling': init_ops.VarianceScaling,
|
||||||
|
'Zeros': init_ops.Zeros,
|
||||||
|
'HeNormal': initializers_v1.HeNormal,
|
||||||
|
'HeUniform': initializers_v1.HeUniform,
|
||||||
|
'LecunNormal': initializers_v1.LecunNormal,
|
||||||
|
'LecunUniform': initializers_v1.LecunUniform,
|
||||||
|
'RandomNormal': initializers_v1.RandomNormal,
|
||||||
|
'RandomUniform': initializers_v1.RandomUniform,
|
||||||
|
'TruncatedNormal': initializers_v1.TruncatedNormal,
|
||||||
|
}
|
||||||
|
for key, value in v1_objs.items():
|
||||||
|
LOCAL.ALL_OBJECTS[key] = value
|
||||||
|
# Functional aliases.
|
||||||
|
LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value
|
||||||
|
|
||||||
|
# More compatibility aliases.
|
||||||
|
LOCAL.ALL_OBJECTS['normal'] = LOCAL.ALL_OBJECTS['random_normal']
|
||||||
|
LOCAL.ALL_OBJECTS['uniform'] = LOCAL.ALL_OBJECTS['random_uniform']
|
||||||
|
LOCAL.ALL_OBJECTS['one'] = LOCAL.ALL_OBJECTS['ones']
|
||||||
|
LOCAL.ALL_OBJECTS['zero'] = LOCAL.ALL_OBJECTS['zeros']
|
||||||
|
|
||||||
|
|
||||||
|
# For backwards compatibility, we populate this file with the objects
|
||||||
|
# from ALL_OBJECTS. We make no guarantees as to whether these objects will
|
||||||
|
# using their correct version.
|
||||||
|
populate_deserializable_objects()
|
||||||
|
globals().update(LOCAL.ALL_OBJECTS)
|
||||||
|
|
||||||
# Utility functions
|
# Utility functions
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.initializers.serialize')
|
@keras_export('keras.initializers.serialize')
|
||||||
def serialize(initializer):
|
def serialize(initializer):
|
||||||
return serialize_keras_object(initializer)
|
return generic_utils.serialize_keras_object(initializer)
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.initializers.deserialize')
|
@keras_export('keras.initializers.deserialize')
|
||||||
def deserialize(config, custom_objects=None):
|
def deserialize(config, custom_objects=None):
|
||||||
"""Return an `Initializer` object from its config."""
|
"""Return an `Initializer` object from its config."""
|
||||||
module_objects = globals()
|
populate_deserializable_objects()
|
||||||
return deserialize_keras_object(
|
return generic_utils.deserialize_keras_object(
|
||||||
config,
|
config,
|
||||||
module_objects=module_objects,
|
module_objects=LOCAL.ALL_OBJECTS,
|
||||||
custom_objects=custom_objects,
|
custom_objects=custom_objects,
|
||||||
printable_module_name='initializer')
|
printable_module_name='initializer')
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ class KerasInitializersTest(test.TestCase):
|
|||||||
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
||||||
std = np.sqrt(1. / fan_in)
|
std = np.sqrt(1. / fan_in)
|
||||||
self._runner(
|
self._runner(
|
||||||
initializers.lecun_uniformV2(seed=123),
|
initializers.LecunUniformV2(seed=123),
|
||||||
tensor_shape,
|
tensor_shape,
|
||||||
target_mean=0.,
|
target_mean=0.,
|
||||||
target_std=std)
|
target_std=std)
|
||||||
@ -113,7 +113,7 @@ class KerasInitializersTest(test.TestCase):
|
|||||||
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
||||||
std = np.sqrt(2. / fan_in)
|
std = np.sqrt(2. / fan_in)
|
||||||
self._runner(
|
self._runner(
|
||||||
initializers.he_uniformV2(seed=123),
|
initializers.HeUniformV2(seed=123),
|
||||||
tensor_shape,
|
tensor_shape,
|
||||||
target_mean=0.,
|
target_mean=0.,
|
||||||
target_std=std)
|
target_std=std)
|
||||||
@ -124,7 +124,7 @@ class KerasInitializersTest(test.TestCase):
|
|||||||
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
||||||
std = np.sqrt(1. / fan_in)
|
std = np.sqrt(1. / fan_in)
|
||||||
self._runner(
|
self._runner(
|
||||||
initializers.lecun_normalV2(seed=123),
|
initializers.LecunNormalV2(seed=123),
|
||||||
tensor_shape,
|
tensor_shape,
|
||||||
target_mean=0.,
|
target_mean=0.,
|
||||||
target_std=std)
|
target_std=std)
|
||||||
@ -146,7 +146,7 @@ class KerasInitializersTest(test.TestCase):
|
|||||||
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
fan_in, _ = init_ops._compute_fans(tensor_shape)
|
||||||
std = np.sqrt(2. / fan_in)
|
std = np.sqrt(2. / fan_in)
|
||||||
self._runner(
|
self._runner(
|
||||||
initializers.he_normalV2(seed=123),
|
initializers.HeNormalV2(seed=123),
|
||||||
tensor_shape,
|
tensor_shape,
|
||||||
target_mean=0.,
|
target_mean=0.,
|
||||||
target_std=std)
|
target_std=std)
|
||||||
|
@ -569,7 +569,7 @@ class ConvLSTM2DCell(DropoutRNNCellMixin, Layer):
|
|||||||
def bias_initializer(_, *args, **kwargs):
|
def bias_initializer(_, *args, **kwargs):
|
||||||
return K.concatenate([
|
return K.concatenate([
|
||||||
self.bias_initializer((self.filters,), *args, **kwargs),
|
self.bias_initializer((self.filters,), *args, **kwargs),
|
||||||
initializers.Ones()((self.filters,), *args, **kwargs),
|
initializers.get('ones')((self.filters,), *args, **kwargs),
|
||||||
self.bias_initializer((self.filters * 2,), *args, **kwargs),
|
self.bias_initializer((self.filters * 2,), *args, **kwargs),
|
||||||
])
|
])
|
||||||
else:
|
else:
|
||||||
|
@ -2340,7 +2340,7 @@ class LSTMCell(DropoutRNNCellMixin, Layer):
|
|||||||
def bias_initializer(_, *args, **kwargs):
|
def bias_initializer(_, *args, **kwargs):
|
||||||
return K.concatenate([
|
return K.concatenate([
|
||||||
self.bias_initializer((self.units,), *args, **kwargs),
|
self.bias_initializer((self.units,), *args, **kwargs),
|
||||||
initializers.Ones()((self.units,), *args, **kwargs),
|
initializers.get('ones')((self.units,), *args, **kwargs),
|
||||||
self.bias_initializer((self.units * 2,), *args, **kwargs),
|
self.bias_initializer((self.units * 2,), *args, **kwargs),
|
||||||
])
|
])
|
||||||
else:
|
else:
|
||||||
|
@ -21,51 +21,125 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import threading
|
||||||
|
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.keras.engine.base_layer import AddLoss
|
from tensorflow.python.keras.engine import base_layer
|
||||||
from tensorflow.python.keras.engine.base_layer import AddMetric
|
from tensorflow.python.keras.engine import input_layer
|
||||||
from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
|
from tensorflow.python.keras.layers import advanced_activations
|
||||||
from tensorflow.python.keras.engine.input_layer import Input
|
from tensorflow.python.keras.layers import convolutional
|
||||||
from tensorflow.python.keras.engine.input_layer import InputLayer
|
from tensorflow.python.keras.layers import convolutional_recurrent
|
||||||
from tensorflow.python.keras.layers.advanced_activations import *
|
from tensorflow.python.keras.layers import core
|
||||||
from tensorflow.python.keras.layers.convolutional import *
|
from tensorflow.python.keras.layers import cudnn_recurrent
|
||||||
from tensorflow.python.keras.layers.convolutional_recurrent import *
|
from tensorflow.python.keras.layers import dense_attention
|
||||||
from tensorflow.python.keras.layers.core import *
|
from tensorflow.python.keras.layers import embeddings
|
||||||
from tensorflow.python.keras.layers.cudnn_recurrent import *
|
from tensorflow.python.keras.layers import local
|
||||||
from tensorflow.python.keras.layers.dense_attention import *
|
from tensorflow.python.keras.layers import merge
|
||||||
from tensorflow.python.keras.layers.embeddings import *
|
from tensorflow.python.keras.layers import noise
|
||||||
from tensorflow.python.keras.layers.local import *
|
from tensorflow.python.keras.layers import normalization
|
||||||
from tensorflow.python.keras.layers.merge import *
|
from tensorflow.python.keras.layers import normalization_v2
|
||||||
from tensorflow.python.keras.layers.noise import *
|
from tensorflow.python.keras.layers import pooling
|
||||||
from tensorflow.python.keras.layers.normalization import *
|
from tensorflow.python.keras.layers import recurrent
|
||||||
from tensorflow.python.keras.layers.pooling import *
|
from tensorflow.python.keras.layers import recurrent_v2
|
||||||
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import *
|
from tensorflow.python.keras.layers import rnn_cell_wrapper_v2
|
||||||
from tensorflow.python.keras.layers.preprocessing.normalization_v1 import *
|
from tensorflow.python.keras.layers import wrappers
|
||||||
from tensorflow.python.keras.layers.recurrent import *
|
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
|
||||||
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import *
|
from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization
|
||||||
from tensorflow.python.keras.layers.wrappers import *
|
from tensorflow.python.keras.layers.preprocessing import normalization_v1 as preprocessing_normalization_v1
|
||||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
from tensorflow.python.util import tf_inspect as inspect
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
if tf2.enabled():
|
|
||||||
from tensorflow.python.keras.layers.normalization_v2 import * # pylint: disable=g-import-not-at-top
|
|
||||||
from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top
|
|
||||||
from tensorflow.python.keras.layers.preprocessing.normalization import * # pylint: disable=g-import-not-at-top
|
|
||||||
|
|
||||||
# This deserialization table is added for backward compatibility, as in TF 1.13,
|
ALL_MODULES = (
|
||||||
# BatchNormalizationV1 and BatchNormalizationV2 are used as class name for v1
|
base_layer,
|
||||||
# and v2 version of BatchNormalization, respectively. Here we explicitly convert
|
input_layer,
|
||||||
# them to the canonical name in the config of deserialization.
|
advanced_activations,
|
||||||
_DESERIALIZATION_TABLE = {
|
convolutional,
|
||||||
'BatchNormalizationV1': 'BatchNormalization',
|
convolutional_recurrent,
|
||||||
'BatchNormalizationV2': 'BatchNormalization',
|
core,
|
||||||
}
|
cudnn_recurrent,
|
||||||
|
dense_attention,
|
||||||
|
embeddings,
|
||||||
|
local,
|
||||||
|
merge,
|
||||||
|
noise,
|
||||||
|
normalization,
|
||||||
|
pooling,
|
||||||
|
image_preprocessing,
|
||||||
|
preprocessing_normalization_v1,
|
||||||
|
recurrent,
|
||||||
|
wrappers
|
||||||
|
)
|
||||||
|
ALL_V2_MODULES = (
|
||||||
|
rnn_cell_wrapper_v2,
|
||||||
|
normalization_v2,
|
||||||
|
recurrent_v2,
|
||||||
|
preprocessing_normalization
|
||||||
|
)
|
||||||
|
|
||||||
|
# ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
|
||||||
|
# thread-local to avoid concurrent mutations.
|
||||||
|
LOCAL = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
def populate_deserializable_objects():
|
||||||
|
"""Populates dict ALL_OBJECTS with every built-in layer.
|
||||||
|
"""
|
||||||
|
global LOCAL
|
||||||
|
if not hasattr(LOCAL, 'ALL_OBJECTS'):
|
||||||
|
LOCAL.ALL_OBJECTS = {}
|
||||||
|
LOCAL.GENERATED_WITH_V2 = None
|
||||||
|
|
||||||
|
if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
|
||||||
|
# Objects dict is already generated for the proper TF version:
|
||||||
|
# do nothing.
|
||||||
|
return
|
||||||
|
|
||||||
|
LOCAL.ALL_OBJECTS = {}
|
||||||
|
LOCAL.GENERATED_WITH_V2 = tf2.enabled()
|
||||||
|
|
||||||
|
base_cls = base_layer.Layer
|
||||||
|
generic_utils.populate_dict_with_module_objects(
|
||||||
|
LOCAL.ALL_OBJECTS,
|
||||||
|
ALL_MODULES,
|
||||||
|
obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
|
||||||
|
|
||||||
|
# Overwrite certain V1 objects with V2 versions
|
||||||
|
if tf2.enabled():
|
||||||
|
generic_utils.populate_dict_with_module_objects(
|
||||||
|
LOCAL.ALL_OBJECTS,
|
||||||
|
ALL_V2_MODULES,
|
||||||
|
obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
|
||||||
|
|
||||||
|
# These deserialization aliases are added for backward compatibility,
|
||||||
|
# as in TF 1.13, "BatchNormalizationV1" and "BatchNormalizationV2"
|
||||||
|
# were used as class name for v1 and v2 version of BatchNormalization,
|
||||||
|
# respectively. Here we explicitly convert them to their canonical names.
|
||||||
|
LOCAL.ALL_OBJECTS['BatchNormalizationV1'] = normalization.BatchNormalization
|
||||||
|
LOCAL.ALL_OBJECTS[
|
||||||
|
'BatchNormalizationV2'] = normalization_v2.BatchNormalization
|
||||||
|
|
||||||
|
# Prevent circular dependencies.
|
||||||
|
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
|
||||||
|
from tensorflow.python.keras.premade.linear import LinearModel # pylint: disable=g-import-not-at-top
|
||||||
|
from tensorflow.python.keras.premade.wide_deep import WideDeepModel # pylint: disable=g-import-not-at-top
|
||||||
|
from tensorflow.python.feature_column import dense_features # pylint: disable=g-import-not-at-top
|
||||||
|
from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top
|
||||||
|
|
||||||
|
LOCAL.ALL_OBJECTS['Input'] = input_layer.Input
|
||||||
|
LOCAL.ALL_OBJECTS['Network'] = models.Network
|
||||||
|
LOCAL.ALL_OBJECTS['Model'] = models.Model
|
||||||
|
LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential
|
||||||
|
LOCAL.ALL_OBJECTS['LinearModel'] = LinearModel
|
||||||
|
LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel
|
||||||
|
LOCAL.ALL_OBJECTS['DenseFeatures'] = dense_features.DenseFeatures
|
||||||
|
LOCAL.ALL_OBJECTS['SequenceFeatures'] = sfc.SequenceFeatures
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.layers.serialize')
|
@keras_export('keras.layers.serialize')
|
||||||
def serialize(layer):
|
def serialize(layer):
|
||||||
return serialize_keras_object(layer)
|
return generic_utils.serialize_keras_object(layer)
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.layers.deserialize')
|
@keras_export('keras.layers.deserialize')
|
||||||
@ -80,30 +154,9 @@ def deserialize(config, custom_objects=None):
|
|||||||
Returns:
|
Returns:
|
||||||
Layer instance (may be Model, Sequential, Network, Layer...)
|
Layer instance (may be Model, Sequential, Network, Layer...)
|
||||||
"""
|
"""
|
||||||
# Prevent circular dependencies.
|
populate_deserializable_objects()
|
||||||
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
|
return generic_utils.deserialize_keras_object(
|
||||||
from tensorflow.python.keras.premade.linear import LinearModel # pylint: disable=g-import-not-at-top
|
|
||||||
from tensorflow.python.keras.premade.wide_deep import WideDeepModel # pylint: disable=g-import-not-at-top
|
|
||||||
from tensorflow.python.feature_column import dense_features # pylint: disable=g-import-not-at-top
|
|
||||||
from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top
|
|
||||||
|
|
||||||
globs = globals() # All layers.
|
|
||||||
globs['Network'] = models.Network
|
|
||||||
globs['Model'] = models.Model
|
|
||||||
globs['Sequential'] = models.Sequential
|
|
||||||
globs['LinearModel'] = LinearModel
|
|
||||||
globs['WideDeepModel'] = WideDeepModel
|
|
||||||
|
|
||||||
# Prevent circular dependencies with FeatureColumn serialization.
|
|
||||||
globs['DenseFeatures'] = dense_features.DenseFeatures
|
|
||||||
globs['SequenceFeatures'] = sfc.SequenceFeatures
|
|
||||||
|
|
||||||
layer_class_name = config['class_name']
|
|
||||||
if layer_class_name in _DESERIALIZATION_TABLE:
|
|
||||||
config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name]
|
|
||||||
|
|
||||||
return deserialize_keras_object(
|
|
||||||
config,
|
config,
|
||||||
module_objects=globs,
|
module_objects=LOCAL.ALL_OBJECTS,
|
||||||
custom_objects=custom_objects,
|
custom_objects=custom_objects,
|
||||||
printable_module_name='layer')
|
printable_module_name='layer')
|
||||||
|
@ -124,12 +124,6 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase):
|
|||||||
layer = batchnorm_layer(
|
layer = batchnorm_layer(
|
||||||
momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2')
|
momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2')
|
||||||
config = keras.layers.serialize(layer)
|
config = keras.layers.serialize(layer)
|
||||||
# To simulate if BatchNormalizationV1 or BatchNormalizationV2 appears in the
|
|
||||||
# saved model.
|
|
||||||
if batchnorm_layer is batchnorm_v1.BatchNormalization:
|
|
||||||
config['class_name'] = 'BatchNormalizationV1'
|
|
||||||
else:
|
|
||||||
config['class_name'] = 'BatchNormalizationV2'
|
|
||||||
new_layer = keras.layers.deserialize(config)
|
new_layer = keras.layers.deserialize(config)
|
||||||
self.assertEqual(new_layer.momentum, 0.9)
|
self.assertEqual(new_layer.momentum, 0.9)
|
||||||
if tf2.enabled():
|
if tf2.enabled():
|
||||||
|
@ -364,7 +364,8 @@ def deserialize_keras_object(identifier,
|
|||||||
else:
|
else:
|
||||||
obj = module_objects.get(object_name)
|
obj = module_objects.get(object_name)
|
||||||
if obj is None:
|
if obj is None:
|
||||||
raise ValueError('Unknown ' + printable_module_name + ':' + object_name)
|
raise ValueError(
|
||||||
|
'Unknown ' + printable_module_name + ': ' + object_name)
|
||||||
# Classes passed by name are instantiated with no args, functions are
|
# Classes passed by name are instantiated with no args, functions are
|
||||||
# returned as-is.
|
# returned as-is.
|
||||||
if tf_inspect.isclass(obj):
|
if tf_inspect.isclass(obj):
|
||||||
@ -783,6 +784,13 @@ def is_default(method):
|
|||||||
return getattr(method, '_is_default', False)
|
return getattr(method, '_is_default', False)
|
||||||
|
|
||||||
|
|
||||||
|
def populate_dict_with_module_objects(target_dict, modules, obj_filter):
|
||||||
|
for module in modules:
|
||||||
|
for name in dir(module):
|
||||||
|
obj = getattr(module, name)
|
||||||
|
if obj_filter(obj):
|
||||||
|
target_dict[name] = obj
|
||||||
|
|
||||||
# Aliases
|
# Aliases
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user