Save build input shape to SavedModel metadata, and use the shape information when loading layers.
This CL also adds a custom JSON encoder/decoder to handle the build_input_shape, which can be a TensorShape or tuple. As with cl/293453611, this resolves loading built-in preprocessing layers but does not address custom preprocessing layers. PiperOrigin-RevId: 296997111 Change-Id: Ic5831471f0d823a9eed7a00b28f6a8a8f9b991b5
This commit is contained in:
parent
8987c83721
commit
0225ad8ca1
@ -383,6 +383,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||||||
self._auto_track_sub_layers = True
|
self._auto_track_sub_layers = True
|
||||||
|
|
||||||
@trackable.no_automatic_dependency_tracking
|
@trackable.no_automatic_dependency_tracking
|
||||||
|
@base_layer_utils.default
|
||||||
def build(self, input_shape):
|
def build(self, input_shape):
|
||||||
"""Creates the variables of the layer (optional, for subclass implementers).
|
"""Creates the variables of the layer (optional, for subclass implementers).
|
||||||
|
|
||||||
|
@ -178,6 +178,7 @@ class Layer(base_layer.Layer):
|
|||||||
# Indicates whether `build` needs to be called upon layer call, to create
|
# Indicates whether `build` needs to be called upon layer call, to create
|
||||||
# the layer's weights.
|
# the layer's weights.
|
||||||
self.built = False
|
self.built = False
|
||||||
|
self._build_input_shape = None
|
||||||
# Provides information about which inputs are compatible with the layer.
|
# Provides information about which inputs are compatible with the layer.
|
||||||
self._input_spec = None
|
self._input_spec = None
|
||||||
self.supports_masking = False
|
self.supports_masking = False
|
||||||
@ -252,6 +253,8 @@ class Layer(base_layer.Layer):
|
|||||||
# might want to turn it off, like Sequential model.
|
# might want to turn it off, like Sequential model.
|
||||||
self._auto_track_sub_layers = True
|
self._auto_track_sub_layers = True
|
||||||
|
|
||||||
|
@trackable.no_automatic_dependency_tracking
|
||||||
|
@base_layer_utils.default
|
||||||
def build(self, input_shape):
|
def build(self, input_shape):
|
||||||
"""Creates the variables of the layer (optional, for subclass implementers).
|
"""Creates the variables of the layer (optional, for subclass implementers).
|
||||||
|
|
||||||
@ -266,6 +269,8 @@ class Layer(base_layer.Layer):
|
|||||||
`TensorShape` if the layer expects a list of inputs
|
`TensorShape` if the layer expects a list of inputs
|
||||||
(one instance per input).
|
(one instance per input).
|
||||||
"""
|
"""
|
||||||
|
if not hasattr(self.build, '_is_default'):
|
||||||
|
self._build_input_shape = input_shape
|
||||||
self.built = True
|
self.built = True
|
||||||
|
|
||||||
@doc_controls.for_subclass_implementers
|
@doc_controls.for_subclass_implementers
|
||||||
|
@ -19,6 +19,7 @@ py_library(
|
|||||||
"save.py",
|
"save.py",
|
||||||
"saved_model/base_serialization.py",
|
"saved_model/base_serialization.py",
|
||||||
"saved_model/constants.py",
|
"saved_model/constants.py",
|
||||||
|
"saved_model/json_utils.py",
|
||||||
"saved_model/layer_serialization.py",
|
"saved_model/layer_serialization.py",
|
||||||
"saved_model/load.py",
|
"saved_model/load.py",
|
||||||
"saved_model/model_serialization.py",
|
"saved_model/model_serialization.py",
|
||||||
@ -164,6 +165,7 @@ tf_py_test(
|
|||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["saved_model/revive_test.py"],
|
srcs = ["saved_model/revive_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
|
shard_count = 4,
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python/keras",
|
"//tensorflow/python/keras",
|
||||||
@ -171,3 +173,15 @@ tf_py_test(
|
|||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "json_utils_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["saved_model/json_utils_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -19,12 +19,10 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import json
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
from tensorflow.python.util import serialization
|
|
||||||
|
|
||||||
|
|
||||||
@six.add_metaclass(abc.ABCMeta)
|
@six.add_metaclass(abc.ABCMeta)
|
||||||
@ -53,9 +51,7 @@ class SavedModelSaver(object):
|
|||||||
"""
|
"""
|
||||||
# TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
|
# TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
|
||||||
# object is in the python property)
|
# object is in the python property)
|
||||||
return json.dumps(
|
return json_utils.Encoder().encode(self.python_properties)
|
||||||
self.python_properties,
|
|
||||||
default=serialization.get_json_type)
|
|
||||||
|
|
||||||
def list_extra_dependencies_for_serialization(self, serialization_cache):
|
def list_extra_dependencies_for_serialization(self, serialization_cache):
|
||||||
"""Lists extra dependencies to serialize to SavedModel.
|
"""Lists extra dependencies to serialize to SavedModel.
|
||||||
|
69
tensorflow/python/keras/saving/saved_model/json_utils.py
Normal file
69
tensorflow/python/keras/saving/saved_model/json_utils.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utils for creating and loading the Layer metadata for SavedModel.
|
||||||
|
|
||||||
|
These are required to retain the original format of the build input shape, since
|
||||||
|
layers and models may have different build behaviors depending on if the shape
|
||||||
|
is a list, tuple, or TensorShape. For example, Network.build() will create
|
||||||
|
separate inputs if the given input_shape is a list, and will create a single
|
||||||
|
input if the given shape is a tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.util import serialization
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(json.JSONEncoder):
|
||||||
|
"""JSON encoder and decoder that handles TensorShapes and tuples."""
|
||||||
|
|
||||||
|
def default(self, obj):
|
||||||
|
if isinstance(obj, tensor_shape.TensorShape):
|
||||||
|
items = obj.as_list() if obj.rank is not None else None
|
||||||
|
return {'class_name': 'TensorShape', 'items': items}
|
||||||
|
return serialization.get_json_type(obj)
|
||||||
|
|
||||||
|
def encode(self, obj):
|
||||||
|
return super(Encoder, self).encode(_encode_tuple(obj))
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_tuple(x):
|
||||||
|
if isinstance(x, tuple):
|
||||||
|
return {'class_name': '__tuple__',
|
||||||
|
'items': tuple(_encode_tuple(i) for i in x)}
|
||||||
|
elif isinstance(x, list):
|
||||||
|
return [_encode_tuple(i) for i in x]
|
||||||
|
elif isinstance(x, dict):
|
||||||
|
return {key: _encode_tuple(value) for key, value in x.items()}
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def decode(json_string):
|
||||||
|
return json.loads(json_string, object_hook=_decode_helper)
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_helper(obj):
|
||||||
|
if isinstance(obj, dict) and 'class_name' in obj:
|
||||||
|
if obj['class_name'] == 'TensorShape':
|
||||||
|
return tensor_shape.TensorShape(obj['items'])
|
||||||
|
elif obj['class_name'] == '__tuple__':
|
||||||
|
return tuple(_decode_helper(i) for i in obj['items'])
|
||||||
|
return obj
|
@ -0,0 +1,55 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
"""Tests the JSON encoder and decoder."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class JsonUtilsTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_encode_decode_tensor_shape(self):
|
||||||
|
metadata = {
|
||||||
|
'key1': tensor_shape.TensorShape(None),
|
||||||
|
'key2': [tensor_shape.TensorShape([None]),
|
||||||
|
tensor_shape.TensorShape([3, None, 5])]}
|
||||||
|
string = json_utils.Encoder().encode(metadata)
|
||||||
|
loaded = json_utils.decode(string)
|
||||||
|
|
||||||
|
self.assertEqual(set(loaded.keys()), {'key1', 'key2'})
|
||||||
|
self.assertAllEqual(loaded['key1'].rank, None)
|
||||||
|
self.assertAllEqual(loaded['key2'][0].as_list(), [None])
|
||||||
|
self.assertAllEqual(loaded['key2'][1].as_list(), [3, None, 5])
|
||||||
|
|
||||||
|
def test_encode_decode_tuple(self):
|
||||||
|
metadata = {
|
||||||
|
'key1': (3, 5),
|
||||||
|
'key2': [(1, (3, 4)), (1,)]}
|
||||||
|
string = json_utils.Encoder().encode(metadata)
|
||||||
|
loaded = json_utils.decode(string)
|
||||||
|
|
||||||
|
self.assertEqual(set(loaded.keys()), {'key1', 'key2'})
|
||||||
|
self.assertAllEqual(loaded['key1'], (3, 5))
|
||||||
|
self.assertAllEqual(loaded['key2'], [(1, (3, 4)), (1,)])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
@ -68,6 +68,8 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
|
|||||||
hasattr(self.obj.activity_regularizer, 'get_config')):
|
hasattr(self.obj.activity_regularizer, 'get_config')):
|
||||||
metadata['activity_regularizer'] = generic_utils.serialize_keras_object(
|
metadata['activity_regularizer'] = generic_utils.serialize_keras_object(
|
||||||
self.obj.activity_regularizer)
|
self.obj.activity_regularizer)
|
||||||
|
if self.obj._build_input_shape is not None: # pylint: disable=protected-access
|
||||||
|
metadata['build_input_shape'] = self.obj._build_input_shape # pylint: disable=protected-access
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def objects_to_serialize(self, serialization_cache):
|
def objects_to_serialize(self, serialization_cache):
|
||||||
|
@ -17,7 +17,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -29,6 +28,7 @@ from tensorflow.python.keras import regularizers
|
|||||||
from tensorflow.python.keras.engine import input_spec
|
from tensorflow.python.keras.engine import input_spec
|
||||||
from tensorflow.python.keras.saving import saving_utils
|
from tensorflow.python.keras.saving import saving_utils
|
||||||
from tensorflow.python.keras.saving.saved_model import constants
|
from tensorflow.python.keras.saving.saved_model import constants
|
||||||
|
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||||
from tensorflow.python.keras.saving.saved_model import utils
|
from tensorflow.python.keras.saving.saved_model import utils
|
||||||
from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints
|
from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints
|
||||||
from tensorflow.python.keras.utils import generic_utils
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
@ -40,6 +40,7 @@ from tensorflow.python.training.tracking import base as trackable
|
|||||||
from tensorflow.python.training.tracking.tracking import delete_tracking
|
from tensorflow.python.training.tracking.tracking import delete_tracking
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
from tensorflow.python.util import object_identity
|
||||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||||
|
|
||||||
# To avoid circular dependencies between keras/engine and keras/saving,
|
# To avoid circular dependencies between keras/engine and keras/saving,
|
||||||
@ -164,6 +165,8 @@ class KerasObjectLoader(tf_load.Loader):
|
|||||||
# records all nodes that were generated directly/indirectly from the config,
|
# records all nodes that were generated directly/indirectly from the config,
|
||||||
# so that they do not get recreated multiple times.
|
# so that they do not get recreated multiple times.
|
||||||
self._nodes_recreated_from_config = {}
|
self._nodes_recreated_from_config = {}
|
||||||
|
self._all_nodes_recreated_from_config = (
|
||||||
|
object_identity.ObjectIdentityWeakSet())
|
||||||
# Store all node ids that have already been traversed when tracking nodes
|
# Store all node ids that have already been traversed when tracking nodes
|
||||||
# that were recreated from the config.
|
# that were recreated from the config.
|
||||||
self._traversed_nodes_from_config = []
|
self._traversed_nodes_from_config = []
|
||||||
@ -227,30 +230,36 @@ class KerasObjectLoader(tf_load.Loader):
|
|||||||
return
|
return
|
||||||
self._traversed_nodes_from_config.append(node_id)
|
self._traversed_nodes_from_config.append(node_id)
|
||||||
obj._maybe_initialize_trackable()
|
obj._maybe_initialize_trackable()
|
||||||
|
|
||||||
for reference in proto.children:
|
for reference in proto.children:
|
||||||
obj_child = obj._lookup_dependency(reference.local_name)
|
obj_child = obj._lookup_dependency(reference.local_name)
|
||||||
setter = setattr
|
child_id = reference.node_id
|
||||||
|
child_proto = self._proto.nodes[child_id]
|
||||||
|
|
||||||
if not isinstance(obj_child, trackable.Trackable):
|
if not isinstance(obj_child, trackable.Trackable):
|
||||||
continue
|
continue
|
||||||
if obj_child._object_identifier in revived_types.registered_identifiers():
|
if (child_proto.user_object.identifier in
|
||||||
setter = lambda *unused: None
|
revived_types.registered_identifiers()):
|
||||||
|
setter = revived_types.get_setter(child_proto.user_object)
|
||||||
elif obj_child._object_identifier in KERAS_OBJECT_IDENTIFIERS:
|
elif obj_child._object_identifier in KERAS_OBJECT_IDENTIFIERS:
|
||||||
metadata = self._proto.nodes[reference.node_id].user_object.metadata
|
|
||||||
setter = _revive_setter
|
setter = _revive_setter
|
||||||
_add_serialized_attributes(obj_child, json.loads(metadata))
|
else:
|
||||||
|
setter = setattr
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
if (reference.node_id in self._nodes_recreated_from_config and
|
|
||||||
self._nodes_recreated_from_config[reference.node_id][0] is not
|
if (child_id in self._nodes_recreated_from_config and
|
||||||
obj_child):
|
self._nodes_recreated_from_config[child_id][0] is not obj_child):
|
||||||
# This means that the same trackable object is referenced by two
|
# This means that the same trackable object is referenced by two
|
||||||
# different objects that were recreated from the config.
|
# different objects that were recreated from the config.
|
||||||
logging.warn('Looks like there is an object (perhaps variable or layer)'
|
logging.warn('Looks like there is an object (perhaps variable or layer)'
|
||||||
' that is shared between different layers/models. This '
|
' that is shared between different layers/models. This '
|
||||||
'may cause issues when training the model. Object: {}'
|
'may cause issues when restoring the variable values.'
|
||||||
.format(obj_child))
|
'Object: {}'.format(obj_child))
|
||||||
self._nodes_recreated_from_config[reference.node_id] = obj_child, setter
|
self._nodes_recreated_from_config[child_id] = (
|
||||||
|
obj_child, self._config_node_setter(setter))
|
||||||
|
self._all_nodes_recreated_from_config.add(obj_child)
|
||||||
self._add_children_recreated_from_config(
|
self._add_children_recreated_from_config(
|
||||||
obj_child, self._proto.nodes[reference.node_id], reference.node_id)
|
obj_child, child_proto, child_id)
|
||||||
|
|
||||||
def _load_layers(self):
|
def _load_layers(self):
|
||||||
layers = {}
|
layers = {}
|
||||||
@ -262,19 +271,35 @@ class KerasObjectLoader(tf_load.Loader):
|
|||||||
|
|
||||||
def _load_layer(self, proto, node_id):
|
def _load_layer(self, proto, node_id):
|
||||||
"""Load a single layer from a SavedUserObject proto."""
|
"""Load a single layer from a SavedUserObject proto."""
|
||||||
|
metadata = json_utils.decode(proto.metadata)
|
||||||
|
|
||||||
|
# If node was already created
|
||||||
|
if node_id in self._nodes_recreated_from_config:
|
||||||
|
node, setter = self._nodes_recreated_from_config[node_id]
|
||||||
|
|
||||||
|
self._try_build_layer(node, node_id, metadata.get('build_input_shape'))
|
||||||
|
|
||||||
|
# Revive setter requires the object to have a `_serialized_attributes`
|
||||||
|
# property. Add it here.
|
||||||
|
_maybe_add_serialized_attributes(node, metadata)
|
||||||
|
|
||||||
|
config = metadata.get('config')
|
||||||
|
if _is_graph_network(node) and generic_utils.validate_config(config):
|
||||||
|
self.model_layer_dependencies[node_id] = (
|
||||||
|
node, self._get_child_layer_node_ids(node_id, node.name))
|
||||||
|
return node, setter
|
||||||
|
|
||||||
# Detect whether this object can be revived from the config. If not, then
|
# Detect whether this object can be revived from the config. If not, then
|
||||||
# revive from the SavedModel instead.
|
# revive from the SavedModel instead.
|
||||||
metadata = json.loads(proto.metadata)
|
|
||||||
obj, setter = self._revive_from_config(metadata, node_id)
|
obj, setter = self._revive_from_config(metadata, node_id)
|
||||||
if obj is None:
|
if obj is None:
|
||||||
obj, setter = revive_custom_object(proto.identifier, metadata)
|
obj, setter = revive_custom_object(proto.identifier, metadata)
|
||||||
|
|
||||||
if setter == _revive_setter:
|
# Add an attribute that stores the extra functions/objects saved in the
|
||||||
# Add an attribute that stores the extra functions/objects saved in the
|
# SavedModel. Most of these functions/objects are ignored, but some are
|
||||||
# SavedModel. Most of these functions/objects are ignored, but some are
|
# used later in the loading process (e.g. the list of regularization
|
||||||
# used later in the loading process (e.g. the list of regularization
|
# losses, or the training config of compiled models).
|
||||||
# losses, or the training config of compiled models).
|
_maybe_add_serialized_attributes(obj, metadata)
|
||||||
_add_serialized_attributes(obj, metadata)
|
|
||||||
return obj, setter
|
return obj, setter
|
||||||
|
|
||||||
def _revive_from_config(self, metadata, node_id):
|
def _revive_from_config(self, metadata, node_id):
|
||||||
@ -284,8 +309,9 @@ class KerasObjectLoader(tf_load.Loader):
|
|||||||
if obj is None:
|
if obj is None:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
setter = _revive_setter
|
setter = self._config_node_setter(_revive_setter)
|
||||||
self._nodes_recreated_from_config[node_id] = obj, setter
|
self._nodes_recreated_from_config[node_id] = obj, setter
|
||||||
|
self._all_nodes_recreated_from_config.add(obj)
|
||||||
self._add_children_recreated_from_config(
|
self._add_children_recreated_from_config(
|
||||||
obj, self._proto.nodes[node_id], node_id)
|
obj, self._proto.nodes[node_id], node_id)
|
||||||
return obj, setter
|
return obj, setter
|
||||||
@ -300,9 +326,8 @@ class KerasObjectLoader(tf_load.Loader):
|
|||||||
model_is_functional_or_sequential = (
|
model_is_functional_or_sequential = (
|
||||||
metadata.get('is_graph_network', False) or
|
metadata.get('is_graph_network', False) or
|
||||||
metadata['class_name'] == 'Sequential')
|
metadata['class_name'] == 'Sequential')
|
||||||
if (config is None or
|
if not (generic_utils.validate_config(config) and
|
||||||
generic_utils.LAYER_UNDEFINED_CONFIG_KEY in config or
|
model_is_functional_or_sequential):
|
||||||
not model_is_functional_or_sequential):
|
|
||||||
return None # Revive as custom model.
|
return None # Revive as custom model.
|
||||||
|
|
||||||
# Revive functional and sequential models as blank model objects for now (
|
# Revive functional and sequential models as blank model objects for now (
|
||||||
@ -329,7 +354,7 @@ class KerasObjectLoader(tf_load.Loader):
|
|||||||
# found.
|
# found.
|
||||||
class_name = metadata.get('class_name')
|
class_name = metadata.get('class_name')
|
||||||
config = metadata.get('config')
|
config = metadata.get('config')
|
||||||
if config is None or generic_utils.LAYER_UNDEFINED_CONFIG_KEY in config:
|
if not generic_utils.validate_config(config):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -348,16 +373,31 @@ class KerasObjectLoader(tf_load.Loader):
|
|||||||
obj._set_dtype_policy(metadata['dtype'])
|
obj._set_dtype_policy(metadata['dtype'])
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
input_shape = None
|
build_input_shape = metadata.get('build_input_shape')
|
||||||
if not isinstance(obj, input_layer.InputLayer):
|
built = self._try_build_layer(obj, node_id, build_input_shape)
|
||||||
input_shape = self._infer_inputs(node_id, convert_to_shapes=True)
|
|
||||||
if input_shape is None:
|
if not built:
|
||||||
return None
|
# If the layer cannot be built, revive a custom layer instead.
|
||||||
obj.build(input_shape)
|
return None
|
||||||
obj.built = True
|
|
||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
def _try_build_layer(self, obj, node_id, build_input_shape):
|
||||||
|
"""Attempts to build the layer."""
|
||||||
|
if obj.built or hasattr(obj.build, '_is_default'):
|
||||||
|
obj.built = True
|
||||||
|
return True
|
||||||
|
|
||||||
|
if build_input_shape is None:
|
||||||
|
build_input_shape = self._infer_inputs(node_id, convert_to_shapes=True)
|
||||||
|
|
||||||
|
if build_input_shape is not None:
|
||||||
|
obj.build(build_input_shape)
|
||||||
|
base_layer.Layer.build(obj, build_input_shape)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def _load_edges(self):
|
def _load_edges(self):
|
||||||
"""Add edges for all nodes that are not waiting on initialization."""
|
"""Add edges for all nodes that are not waiting on initialization."""
|
||||||
for node_id, proto in enumerate(self._proto.nodes):
|
for node_id, proto in enumerate(self._proto.nodes):
|
||||||
@ -432,8 +472,8 @@ class KerasObjectLoader(tf_load.Loader):
|
|||||||
.format(uninitialized_model_names))
|
.format(uninitialized_model_names))
|
||||||
|
|
||||||
def _reconstruct_model(self, model_id, model, layers):
|
def _reconstruct_model(self, model_id, model, layers):
|
||||||
config = (
|
config = json_utils.decode(
|
||||||
json.loads(self._proto.nodes[model_id].user_object.metadata)['config'])
|
self._proto.nodes[model_id].user_object.metadata)['config']
|
||||||
if isinstance(model, models_lib.Sequential):
|
if isinstance(model, models_lib.Sequential):
|
||||||
if not isinstance(layers[0], input_layer.InputLayer):
|
if not isinstance(layers[0], input_layer.InputLayer):
|
||||||
if 'batch_input_shape' in config['layers'][0]['config']:
|
if 'batch_input_shape' in config['layers'][0]['config']:
|
||||||
@ -502,6 +542,14 @@ class KerasObjectLoader(tf_load.Loader):
|
|||||||
else:
|
else:
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
def _config_node_setter(self, setter):
|
||||||
|
"""Creates edges for nodes that are recreated from config."""
|
||||||
|
def setattr_wrapper(obj, name, value):
|
||||||
|
# Avoid overwriting attributes of objects recreated from the config.
|
||||||
|
if obj._lookup_dependency(name) is None: # pylint: disable=protected-access
|
||||||
|
setter(obj, name, value)
|
||||||
|
return setattr_wrapper
|
||||||
|
|
||||||
|
|
||||||
def _finalize_saved_model_layers(layers):
|
def _finalize_saved_model_layers(layers):
|
||||||
"""Runs the final steps of loading Keras Layers from SavedModel."""
|
"""Runs the final steps of loading Keras Layers from SavedModel."""
|
||||||
@ -626,8 +674,9 @@ class RevivedLayer(object):
|
|||||||
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
|
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
|
||||||
# pylint:disable=protected-access
|
# pylint:disable=protected-access
|
||||||
revived_obj._expects_training_arg = metadata['expects_training_arg']
|
revived_obj._expects_training_arg = metadata['expects_training_arg']
|
||||||
if metadata.get('config') is not None:
|
config = metadata.get('config)')
|
||||||
revived_obj._config = metadata['config']
|
if generic_utils.validate_config(config):
|
||||||
|
revived_obj._config = config
|
||||||
if metadata.get('input_spec') is not None:
|
if metadata.get('input_spec') is not None:
|
||||||
revived_obj.input_spec = recursively_deserialize_keras_object(
|
revived_obj.input_spec = recursively_deserialize_keras_object(
|
||||||
metadata['input_spec'],
|
metadata['input_spec'],
|
||||||
@ -747,8 +796,9 @@ class RevivedNetwork(RevivedLayer):
|
|||||||
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
|
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
|
||||||
# pylint:disable=protected-access
|
# pylint:disable=protected-access
|
||||||
revived_obj._expects_training_arg = metadata['expects_training_arg']
|
revived_obj._expects_training_arg = metadata['expects_training_arg']
|
||||||
if metadata.get('config') is not None:
|
config = metadata.get('config')
|
||||||
revived_obj._config = metadata['config']
|
if generic_utils.validate_config(config):
|
||||||
|
revived_obj._config = config
|
||||||
|
|
||||||
if metadata.get('activity_regularizer') is not None:
|
if metadata.get('activity_regularizer') is not None:
|
||||||
revived_obj.activity_regularizer = regularizers.deserialize(
|
revived_obj.activity_regularizer = regularizers.deserialize(
|
||||||
@ -769,12 +819,13 @@ def _set_network_attributes_from_metadata(revived_obj):
|
|||||||
# pylint:enable=protected-access
|
# pylint:enable=protected-access
|
||||||
|
|
||||||
|
|
||||||
def _add_serialized_attributes(layer, metadata):
|
def _maybe_add_serialized_attributes(layer, metadata):
|
||||||
# Store attributes revived from SerializedAttributes in a un-tracked
|
# Store attributes revived from SerializedAttributes in a un-tracked
|
||||||
# dictionary. The attributes are the ones listed in CommonEndpoints or
|
# dictionary. The attributes are the ones listed in CommonEndpoints or
|
||||||
# "keras_api" for keras-specific attributes.
|
# "keras_api" for keras-specific attributes.
|
||||||
with trackable.no_automatic_dependency_tracking_scope(layer):
|
if not hasattr(layer, '_serialized_attributes'):
|
||||||
layer._serialized_attributes = {'metadata': metadata} # pylint: disable=protected-access
|
with trackable.no_automatic_dependency_tracking_scope(layer):
|
||||||
|
layer._serialized_attributes = {'metadata': metadata} # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
def _get_keras_attr(layer):
|
def _get_keras_attr(layer):
|
||||||
|
@ -50,15 +50,18 @@ class SubclassedModelNoConfig(keras.Model):
|
|||||||
self.a = a
|
self.a = a
|
||||||
self.b = b
|
self.b = b
|
||||||
self.shared = CustomLayerNoConfig(a, b)
|
self.shared = CustomLayerNoConfig(a, b)
|
||||||
self.all_layers = [
|
self.all_layers = []
|
||||||
|
|
||||||
|
def build(self, input_shape):
|
||||||
|
self.all_layers.extend([
|
||||||
self.shared,
|
self.shared,
|
||||||
CustomLayerWithConfig(a + 1, b + 2),
|
CustomLayerWithConfig(self.a + 1, self.b + 2),
|
||||||
CustomLayerNoConfig(a + 3, b + 4),
|
CustomLayerNoConfig(self.a + 3, self.b + 4),
|
||||||
keras.Sequential([
|
keras.Sequential([
|
||||||
# TODO(b/145029112): Bug with losses when there are shared layers.
|
# TODO(b/145029112): Bug with losses when there are shared layers.
|
||||||
# self.shared, <-- Enable when bug is fixed.
|
# self.shared, <-- Enable when bug is fixed.
|
||||||
CustomLayerNoConfig(a + 5, b + 6)
|
CustomLayerNoConfig(self.a + 5, self.b + 6)])])
|
||||||
])]
|
super(SubclassedModelNoConfig, self).build(input_shape)
|
||||||
|
|
||||||
def call(self, inputs):
|
def call(self, inputs):
|
||||||
x = inputs
|
x = inputs
|
||||||
|
@ -44,7 +44,7 @@ _GLOBAL_CUSTOM_NAMES = {}
|
|||||||
_SKIP_FAILED_SERIALIZATION = False
|
_SKIP_FAILED_SERIALIZATION = False
|
||||||
# If a layer does not have a defined config, then the returned config will be a
|
# If a layer does not have a defined config, then the returned config will be a
|
||||||
# dictionary with the below key.
|
# dictionary with the below key.
|
||||||
LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'
|
_LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.utils.CustomObjectScope')
|
@keras_export('keras.utils.CustomObjectScope')
|
||||||
@ -271,7 +271,7 @@ def serialize_keras_object(instance):
|
|||||||
except NotImplementedError as e:
|
except NotImplementedError as e:
|
||||||
if _SKIP_FAILED_SERIALIZATION:
|
if _SKIP_FAILED_SERIALIZATION:
|
||||||
return serialize_keras_class_and_config(
|
return serialize_keras_class_and_config(
|
||||||
name, {LAYER_UNDEFINED_CONFIG_KEY: True})
|
name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
|
||||||
raise e
|
raise e
|
||||||
serialization_config = {}
|
serialization_config = {}
|
||||||
for key, item in config.items():
|
for key, item in config.items():
|
||||||
@ -790,3 +790,8 @@ def validate_kwargs(kwargs,
|
|||||||
for kwarg in kwargs:
|
for kwarg in kwargs:
|
||||||
if kwarg not in allowed_kwargs:
|
if kwarg not in allowed_kwargs:
|
||||||
raise TypeError(error_message, kwarg)
|
raise TypeError(error_message, kwarg)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_config(config):
|
||||||
|
"""Determines whether config appears to be a valid layer config."""
|
||||||
|
return isinstance(config, dict) and _LAYER_UNDEFINED_CONFIG_KEY not in config
|
||||||
|
@ -169,3 +169,13 @@ def deserialize(proto):
|
|||||||
|
|
||||||
def registered_identifiers():
|
def registered_identifiers():
|
||||||
return _REVIVED_TYPE_REGISTRY.keys()
|
return _REVIVED_TYPE_REGISTRY.keys()
|
||||||
|
|
||||||
|
|
||||||
|
def get_setter(proto):
|
||||||
|
_, type_registrations = _REVIVED_TYPE_REGISTRY.get(
|
||||||
|
proto.identifier, (None, None))
|
||||||
|
if type_registrations is not None:
|
||||||
|
for type_registration in type_registrations:
|
||||||
|
if type_registration.should_load(proto):
|
||||||
|
return type_registration.setter
|
||||||
|
return None
|
||||||
|
Loading…
Reference in New Issue
Block a user