Save build input shape to SavedModel metadata, and use the shape information when loading layers.
This CL also adds a custom JSON encoder/decoder to handle the build_input_shape, which can be a TensorShape or tuple. As with cl/293453611, this resolves loading built-in preprocessing layers but does not address custom preprocessing layers. PiperOrigin-RevId: 296997111 Change-Id: Ic5831471f0d823a9eed7a00b28f6a8a8f9b991b5
This commit is contained in:
parent
8987c83721
commit
0225ad8ca1
@ -383,6 +383,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
self._auto_track_sub_layers = True
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
@base_layer_utils.default
|
||||
def build(self, input_shape):
|
||||
"""Creates the variables of the layer (optional, for subclass implementers).
|
||||
|
||||
|
@ -178,6 +178,7 @@ class Layer(base_layer.Layer):
|
||||
# Indicates whether `build` needs to be called upon layer call, to create
|
||||
# the layer's weights.
|
||||
self.built = False
|
||||
self._build_input_shape = None
|
||||
# Provides information about which inputs are compatible with the layer.
|
||||
self._input_spec = None
|
||||
self.supports_masking = False
|
||||
@ -252,6 +253,8 @@ class Layer(base_layer.Layer):
|
||||
# might want to turn it off, like Sequential model.
|
||||
self._auto_track_sub_layers = True
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
@base_layer_utils.default
|
||||
def build(self, input_shape):
|
||||
"""Creates the variables of the layer (optional, for subclass implementers).
|
||||
|
||||
@ -266,6 +269,8 @@ class Layer(base_layer.Layer):
|
||||
`TensorShape` if the layer expects a list of inputs
|
||||
(one instance per input).
|
||||
"""
|
||||
if not hasattr(self.build, '_is_default'):
|
||||
self._build_input_shape = input_shape
|
||||
self.built = True
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
|
@ -19,6 +19,7 @@ py_library(
|
||||
"save.py",
|
||||
"saved_model/base_serialization.py",
|
||||
"saved_model/constants.py",
|
||||
"saved_model/json_utils.py",
|
||||
"saved_model/layer_serialization.py",
|
||||
"saved_model/load.py",
|
||||
"saved_model/model_serialization.py",
|
||||
@ -164,6 +165,7 @@ tf_py_test(
|
||||
size = "medium",
|
||||
srcs = ["saved_model/revive_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 4,
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
@ -171,3 +173,15 @@ tf_py_test(
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "json_utils_test",
|
||||
size = "small",
|
||||
srcs = ["saved_model/json_utils_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
@ -19,12 +19,10 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import json
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
@ -53,9 +51,7 @@ class SavedModelSaver(object):
|
||||
"""
|
||||
# TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
|
||||
# object is in the python property)
|
||||
return json.dumps(
|
||||
self.python_properties,
|
||||
default=serialization.get_json_type)
|
||||
return json_utils.Encoder().encode(self.python_properties)
|
||||
|
||||
def list_extra_dependencies_for_serialization(self, serialization_cache):
|
||||
"""Lists extra dependencies to serialize to SavedModel.
|
||||
|
69
tensorflow/python/keras/saving/saved_model/json_utils.py
Normal file
69
tensorflow/python/keras/saving/saved_model/json_utils.py
Normal file
@ -0,0 +1,69 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utils for creating and loading the Layer metadata for SavedModel.
|
||||
|
||||
These are required to retain the original format of the build input shape, since
|
||||
layers and models may have different build behaviors depending on if the shape
|
||||
is a list, tuple, or TensorShape. For example, Network.build() will create
|
||||
separate inputs if the given input_shape is a list, and will create a single
|
||||
input if the given shape is a tuple.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
|
||||
class Encoder(json.JSONEncoder):
|
||||
"""JSON encoder and decoder that handles TensorShapes and tuples."""
|
||||
|
||||
def default(self, obj):
|
||||
if isinstance(obj, tensor_shape.TensorShape):
|
||||
items = obj.as_list() if obj.rank is not None else None
|
||||
return {'class_name': 'TensorShape', 'items': items}
|
||||
return serialization.get_json_type(obj)
|
||||
|
||||
def encode(self, obj):
|
||||
return super(Encoder, self).encode(_encode_tuple(obj))
|
||||
|
||||
|
||||
def _encode_tuple(x):
|
||||
if isinstance(x, tuple):
|
||||
return {'class_name': '__tuple__',
|
||||
'items': tuple(_encode_tuple(i) for i in x)}
|
||||
elif isinstance(x, list):
|
||||
return [_encode_tuple(i) for i in x]
|
||||
elif isinstance(x, dict):
|
||||
return {key: _encode_tuple(value) for key, value in x.items()}
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def decode(json_string):
|
||||
return json.loads(json_string, object_hook=_decode_helper)
|
||||
|
||||
|
||||
def _decode_helper(obj):
|
||||
if isinstance(obj, dict) and 'class_name' in obj:
|
||||
if obj['class_name'] == 'TensorShape':
|
||||
return tensor_shape.TensorShape(obj['items'])
|
||||
elif obj['class_name'] == '__tuple__':
|
||||
return tuple(_decode_helper(i) for i in obj['items'])
|
||||
return obj
|
@ -0,0 +1,55 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# pylint: disable=protected-access
|
||||
"""Tests the JSON encoder and decoder."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class JsonUtilsTest(test.TestCase):
|
||||
|
||||
def test_encode_decode_tensor_shape(self):
|
||||
metadata = {
|
||||
'key1': tensor_shape.TensorShape(None),
|
||||
'key2': [tensor_shape.TensorShape([None]),
|
||||
tensor_shape.TensorShape([3, None, 5])]}
|
||||
string = json_utils.Encoder().encode(metadata)
|
||||
loaded = json_utils.decode(string)
|
||||
|
||||
self.assertEqual(set(loaded.keys()), {'key1', 'key2'})
|
||||
self.assertAllEqual(loaded['key1'].rank, None)
|
||||
self.assertAllEqual(loaded['key2'][0].as_list(), [None])
|
||||
self.assertAllEqual(loaded['key2'][1].as_list(), [3, None, 5])
|
||||
|
||||
def test_encode_decode_tuple(self):
|
||||
metadata = {
|
||||
'key1': (3, 5),
|
||||
'key2': [(1, (3, 4)), (1,)]}
|
||||
string = json_utils.Encoder().encode(metadata)
|
||||
loaded = json_utils.decode(string)
|
||||
|
||||
self.assertEqual(set(loaded.keys()), {'key1', 'key2'})
|
||||
self.assertAllEqual(loaded['key1'], (3, 5))
|
||||
self.assertAllEqual(loaded['key2'], [(1, (3, 4)), (1,)])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -68,6 +68,8 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
|
||||
hasattr(self.obj.activity_regularizer, 'get_config')):
|
||||
metadata['activity_regularizer'] = generic_utils.serialize_keras_object(
|
||||
self.obj.activity_regularizer)
|
||||
if self.obj._build_input_shape is not None: # pylint: disable=protected-access
|
||||
metadata['build_input_shape'] = self.obj._build_input_shape # pylint: disable=protected-access
|
||||
return metadata
|
||||
|
||||
def objects_to_serialize(self, serialization_cache):
|
||||
|
@ -17,7 +17,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
@ -29,6 +28,7 @@ from tensorflow.python.keras import regularizers
|
||||
from tensorflow.python.keras.engine import input_spec
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.saving.saved_model import constants
|
||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.keras.saving.saved_model import utils
|
||||
from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
@ -40,6 +40,7 @@ from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.training.tracking.tracking import delete_tracking
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
|
||||
# To avoid circular dependencies between keras/engine and keras/saving,
|
||||
@ -164,6 +165,8 @@ class KerasObjectLoader(tf_load.Loader):
|
||||
# records all nodes that were generated directly/indirectly from the config,
|
||||
# so that they do not get recreated multiple times.
|
||||
self._nodes_recreated_from_config = {}
|
||||
self._all_nodes_recreated_from_config = (
|
||||
object_identity.ObjectIdentityWeakSet())
|
||||
# Store all node ids that have already been traversed when tracking nodes
|
||||
# that were recreated from the config.
|
||||
self._traversed_nodes_from_config = []
|
||||
@ -227,30 +230,36 @@ class KerasObjectLoader(tf_load.Loader):
|
||||
return
|
||||
self._traversed_nodes_from_config.append(node_id)
|
||||
obj._maybe_initialize_trackable()
|
||||
|
||||
for reference in proto.children:
|
||||
obj_child = obj._lookup_dependency(reference.local_name)
|
||||
setter = setattr
|
||||
child_id = reference.node_id
|
||||
child_proto = self._proto.nodes[child_id]
|
||||
|
||||
if not isinstance(obj_child, trackable.Trackable):
|
||||
continue
|
||||
if obj_child._object_identifier in revived_types.registered_identifiers():
|
||||
setter = lambda *unused: None
|
||||
if (child_proto.user_object.identifier in
|
||||
revived_types.registered_identifiers()):
|
||||
setter = revived_types.get_setter(child_proto.user_object)
|
||||
elif obj_child._object_identifier in KERAS_OBJECT_IDENTIFIERS:
|
||||
metadata = self._proto.nodes[reference.node_id].user_object.metadata
|
||||
setter = _revive_setter
|
||||
_add_serialized_attributes(obj_child, json.loads(metadata))
|
||||
else:
|
||||
setter = setattr
|
||||
# pylint: enable=protected-access
|
||||
if (reference.node_id in self._nodes_recreated_from_config and
|
||||
self._nodes_recreated_from_config[reference.node_id][0] is not
|
||||
obj_child):
|
||||
|
||||
if (child_id in self._nodes_recreated_from_config and
|
||||
self._nodes_recreated_from_config[child_id][0] is not obj_child):
|
||||
# This means that the same trackable object is referenced by two
|
||||
# different objects that were recreated from the config.
|
||||
logging.warn('Looks like there is an object (perhaps variable or layer)'
|
||||
' that is shared between different layers/models. This '
|
||||
'may cause issues when training the model. Object: {}'
|
||||
.format(obj_child))
|
||||
self._nodes_recreated_from_config[reference.node_id] = obj_child, setter
|
||||
'may cause issues when restoring the variable values.'
|
||||
'Object: {}'.format(obj_child))
|
||||
self._nodes_recreated_from_config[child_id] = (
|
||||
obj_child, self._config_node_setter(setter))
|
||||
self._all_nodes_recreated_from_config.add(obj_child)
|
||||
self._add_children_recreated_from_config(
|
||||
obj_child, self._proto.nodes[reference.node_id], reference.node_id)
|
||||
obj_child, child_proto, child_id)
|
||||
|
||||
def _load_layers(self):
|
||||
layers = {}
|
||||
@ -262,19 +271,35 @@ class KerasObjectLoader(tf_load.Loader):
|
||||
|
||||
def _load_layer(self, proto, node_id):
|
||||
"""Load a single layer from a SavedUserObject proto."""
|
||||
metadata = json_utils.decode(proto.metadata)
|
||||
|
||||
# If node was already created
|
||||
if node_id in self._nodes_recreated_from_config:
|
||||
node, setter = self._nodes_recreated_from_config[node_id]
|
||||
|
||||
self._try_build_layer(node, node_id, metadata.get('build_input_shape'))
|
||||
|
||||
# Revive setter requires the object to have a `_serialized_attributes`
|
||||
# property. Add it here.
|
||||
_maybe_add_serialized_attributes(node, metadata)
|
||||
|
||||
config = metadata.get('config')
|
||||
if _is_graph_network(node) and generic_utils.validate_config(config):
|
||||
self.model_layer_dependencies[node_id] = (
|
||||
node, self._get_child_layer_node_ids(node_id, node.name))
|
||||
return node, setter
|
||||
|
||||
# Detect whether this object can be revived from the config. If not, then
|
||||
# revive from the SavedModel instead.
|
||||
metadata = json.loads(proto.metadata)
|
||||
obj, setter = self._revive_from_config(metadata, node_id)
|
||||
if obj is None:
|
||||
obj, setter = revive_custom_object(proto.identifier, metadata)
|
||||
|
||||
if setter == _revive_setter:
|
||||
# Add an attribute that stores the extra functions/objects saved in the
|
||||
# SavedModel. Most of these functions/objects are ignored, but some are
|
||||
# used later in the loading process (e.g. the list of regularization
|
||||
# losses, or the training config of compiled models).
|
||||
_add_serialized_attributes(obj, metadata)
|
||||
# Add an attribute that stores the extra functions/objects saved in the
|
||||
# SavedModel. Most of these functions/objects are ignored, but some are
|
||||
# used later in the loading process (e.g. the list of regularization
|
||||
# losses, or the training config of compiled models).
|
||||
_maybe_add_serialized_attributes(obj, metadata)
|
||||
return obj, setter
|
||||
|
||||
def _revive_from_config(self, metadata, node_id):
|
||||
@ -284,8 +309,9 @@ class KerasObjectLoader(tf_load.Loader):
|
||||
if obj is None:
|
||||
return None, None
|
||||
|
||||
setter = _revive_setter
|
||||
setter = self._config_node_setter(_revive_setter)
|
||||
self._nodes_recreated_from_config[node_id] = obj, setter
|
||||
self._all_nodes_recreated_from_config.add(obj)
|
||||
self._add_children_recreated_from_config(
|
||||
obj, self._proto.nodes[node_id], node_id)
|
||||
return obj, setter
|
||||
@ -300,9 +326,8 @@ class KerasObjectLoader(tf_load.Loader):
|
||||
model_is_functional_or_sequential = (
|
||||
metadata.get('is_graph_network', False) or
|
||||
metadata['class_name'] == 'Sequential')
|
||||
if (config is None or
|
||||
generic_utils.LAYER_UNDEFINED_CONFIG_KEY in config or
|
||||
not model_is_functional_or_sequential):
|
||||
if not (generic_utils.validate_config(config) and
|
||||
model_is_functional_or_sequential):
|
||||
return None # Revive as custom model.
|
||||
|
||||
# Revive functional and sequential models as blank model objects for now (
|
||||
@ -329,7 +354,7 @@ class KerasObjectLoader(tf_load.Loader):
|
||||
# found.
|
||||
class_name = metadata.get('class_name')
|
||||
config = metadata.get('config')
|
||||
if config is None or generic_utils.LAYER_UNDEFINED_CONFIG_KEY in config:
|
||||
if not generic_utils.validate_config(config):
|
||||
return None
|
||||
|
||||
try:
|
||||
@ -348,16 +373,31 @@ class KerasObjectLoader(tf_load.Loader):
|
||||
obj._set_dtype_policy(metadata['dtype'])
|
||||
# pylint: enable=protected-access
|
||||
|
||||
input_shape = None
|
||||
if not isinstance(obj, input_layer.InputLayer):
|
||||
input_shape = self._infer_inputs(node_id, convert_to_shapes=True)
|
||||
if input_shape is None:
|
||||
return None
|
||||
obj.build(input_shape)
|
||||
obj.built = True
|
||||
build_input_shape = metadata.get('build_input_shape')
|
||||
built = self._try_build_layer(obj, node_id, build_input_shape)
|
||||
|
||||
if not built:
|
||||
# If the layer cannot be built, revive a custom layer instead.
|
||||
return None
|
||||
|
||||
return obj
|
||||
|
||||
def _try_build_layer(self, obj, node_id, build_input_shape):
|
||||
"""Attempts to build the layer."""
|
||||
if obj.built or hasattr(obj.build, '_is_default'):
|
||||
obj.built = True
|
||||
return True
|
||||
|
||||
if build_input_shape is None:
|
||||
build_input_shape = self._infer_inputs(node_id, convert_to_shapes=True)
|
||||
|
||||
if build_input_shape is not None:
|
||||
obj.build(build_input_shape)
|
||||
base_layer.Layer.build(obj, build_input_shape)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _load_edges(self):
|
||||
"""Add edges for all nodes that are not waiting on initialization."""
|
||||
for node_id, proto in enumerate(self._proto.nodes):
|
||||
@ -432,8 +472,8 @@ class KerasObjectLoader(tf_load.Loader):
|
||||
.format(uninitialized_model_names))
|
||||
|
||||
def _reconstruct_model(self, model_id, model, layers):
|
||||
config = (
|
||||
json.loads(self._proto.nodes[model_id].user_object.metadata)['config'])
|
||||
config = json_utils.decode(
|
||||
self._proto.nodes[model_id].user_object.metadata)['config']
|
||||
if isinstance(model, models_lib.Sequential):
|
||||
if not isinstance(layers[0], input_layer.InputLayer):
|
||||
if 'batch_input_shape' in config['layers'][0]['config']:
|
||||
@ -502,6 +542,14 @@ class KerasObjectLoader(tf_load.Loader):
|
||||
else:
|
||||
return inputs
|
||||
|
||||
def _config_node_setter(self, setter):
|
||||
"""Creates edges for nodes that are recreated from config."""
|
||||
def setattr_wrapper(obj, name, value):
|
||||
# Avoid overwriting attributes of objects recreated from the config.
|
||||
if obj._lookup_dependency(name) is None: # pylint: disable=protected-access
|
||||
setter(obj, name, value)
|
||||
return setattr_wrapper
|
||||
|
||||
|
||||
def _finalize_saved_model_layers(layers):
|
||||
"""Runs the final steps of loading Keras Layers from SavedModel."""
|
||||
@ -626,8 +674,9 @@ class RevivedLayer(object):
|
||||
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
|
||||
# pylint:disable=protected-access
|
||||
revived_obj._expects_training_arg = metadata['expects_training_arg']
|
||||
if metadata.get('config') is not None:
|
||||
revived_obj._config = metadata['config']
|
||||
config = metadata.get('config)')
|
||||
if generic_utils.validate_config(config):
|
||||
revived_obj._config = config
|
||||
if metadata.get('input_spec') is not None:
|
||||
revived_obj.input_spec = recursively_deserialize_keras_object(
|
||||
metadata['input_spec'],
|
||||
@ -747,8 +796,9 @@ class RevivedNetwork(RevivedLayer):
|
||||
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
|
||||
# pylint:disable=protected-access
|
||||
revived_obj._expects_training_arg = metadata['expects_training_arg']
|
||||
if metadata.get('config') is not None:
|
||||
revived_obj._config = metadata['config']
|
||||
config = metadata.get('config')
|
||||
if generic_utils.validate_config(config):
|
||||
revived_obj._config = config
|
||||
|
||||
if metadata.get('activity_regularizer') is not None:
|
||||
revived_obj.activity_regularizer = regularizers.deserialize(
|
||||
@ -769,12 +819,13 @@ def _set_network_attributes_from_metadata(revived_obj):
|
||||
# pylint:enable=protected-access
|
||||
|
||||
|
||||
def _add_serialized_attributes(layer, metadata):
|
||||
def _maybe_add_serialized_attributes(layer, metadata):
|
||||
# Store attributes revived from SerializedAttributes in a un-tracked
|
||||
# dictionary. The attributes are the ones listed in CommonEndpoints or
|
||||
# "keras_api" for keras-specific attributes.
|
||||
with trackable.no_automatic_dependency_tracking_scope(layer):
|
||||
layer._serialized_attributes = {'metadata': metadata} # pylint: disable=protected-access
|
||||
if not hasattr(layer, '_serialized_attributes'):
|
||||
with trackable.no_automatic_dependency_tracking_scope(layer):
|
||||
layer._serialized_attributes = {'metadata': metadata} # pylint: disable=protected-access
|
||||
|
||||
|
||||
def _get_keras_attr(layer):
|
||||
|
@ -50,15 +50,18 @@ class SubclassedModelNoConfig(keras.Model):
|
||||
self.a = a
|
||||
self.b = b
|
||||
self.shared = CustomLayerNoConfig(a, b)
|
||||
self.all_layers = [
|
||||
self.all_layers = []
|
||||
|
||||
def build(self, input_shape):
|
||||
self.all_layers.extend([
|
||||
self.shared,
|
||||
CustomLayerWithConfig(a + 1, b + 2),
|
||||
CustomLayerNoConfig(a + 3, b + 4),
|
||||
CustomLayerWithConfig(self.a + 1, self.b + 2),
|
||||
CustomLayerNoConfig(self.a + 3, self.b + 4),
|
||||
keras.Sequential([
|
||||
# TODO(b/145029112): Bug with losses when there are shared layers.
|
||||
# self.shared, <-- Enable when bug is fixed.
|
||||
CustomLayerNoConfig(a + 5, b + 6)
|
||||
])]
|
||||
CustomLayerNoConfig(self.a + 5, self.b + 6)])])
|
||||
super(SubclassedModelNoConfig, self).build(input_shape)
|
||||
|
||||
def call(self, inputs):
|
||||
x = inputs
|
||||
|
@ -44,7 +44,7 @@ _GLOBAL_CUSTOM_NAMES = {}
|
||||
_SKIP_FAILED_SERIALIZATION = False
|
||||
# If a layer does not have a defined config, then the returned config will be a
|
||||
# dictionary with the below key.
|
||||
LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'
|
||||
_LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'
|
||||
|
||||
|
||||
@keras_export('keras.utils.CustomObjectScope')
|
||||
@ -271,7 +271,7 @@ def serialize_keras_object(instance):
|
||||
except NotImplementedError as e:
|
||||
if _SKIP_FAILED_SERIALIZATION:
|
||||
return serialize_keras_class_and_config(
|
||||
name, {LAYER_UNDEFINED_CONFIG_KEY: True})
|
||||
name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
|
||||
raise e
|
||||
serialization_config = {}
|
||||
for key, item in config.items():
|
||||
@ -790,3 +790,8 @@ def validate_kwargs(kwargs,
|
||||
for kwarg in kwargs:
|
||||
if kwarg not in allowed_kwargs:
|
||||
raise TypeError(error_message, kwarg)
|
||||
|
||||
|
||||
def validate_config(config):
|
||||
"""Determines whether config appears to be a valid layer config."""
|
||||
return isinstance(config, dict) and _LAYER_UNDEFINED_CONFIG_KEY not in config
|
||||
|
@ -169,3 +169,13 @@ def deserialize(proto):
|
||||
|
||||
def registered_identifiers():
|
||||
return _REVIVED_TYPE_REGISTRY.keys()
|
||||
|
||||
|
||||
def get_setter(proto):
|
||||
_, type_registrations = _REVIVED_TYPE_REGISTRY.get(
|
||||
proto.identifier, (None, None))
|
||||
if type_registrations is not None:
|
||||
for type_registration in type_registrations:
|
||||
if type_registration.should_load(proto):
|
||||
return type_registration.setter
|
||||
return None
|
||||
|
Loading…
Reference in New Issue
Block a user