Save build input shape to SavedModel metadata, and use the shape information when loading layers.

This CL also adds a custom JSON encoder/decoder to handle the build_input_shape, which can be a TensorShape or tuple.

As with cl/293453611, this resolves loading built-in preprocessing layers but does not address custom preprocessing layers.

PiperOrigin-RevId: 296997111
Change-Id: Ic5831471f0d823a9eed7a00b28f6a8a8f9b991b5
This commit is contained in:
Katherine Wu 2020-02-24 16:18:48 -08:00 committed by TensorFlower Gardener
parent 8987c83721
commit 0225ad8ca1
11 changed files with 265 additions and 54 deletions

View File

@ -383,6 +383,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
self._auto_track_sub_layers = True
@trackable.no_automatic_dependency_tracking
@base_layer_utils.default
def build(self, input_shape):
"""Creates the variables of the layer (optional, for subclass implementers).

View File

@ -178,6 +178,7 @@ class Layer(base_layer.Layer):
# Indicates whether `build` needs to be called upon layer call, to create
# the layer's weights.
self.built = False
self._build_input_shape = None
# Provides information about which inputs are compatible with the layer.
self._input_spec = None
self.supports_masking = False
@ -252,6 +253,8 @@ class Layer(base_layer.Layer):
# might want to turn it off, like Sequential model.
self._auto_track_sub_layers = True
@trackable.no_automatic_dependency_tracking
@base_layer_utils.default
def build(self, input_shape):
"""Creates the variables of the layer (optional, for subclass implementers).
@ -266,6 +269,8 @@ class Layer(base_layer.Layer):
`TensorShape` if the layer expects a list of inputs
(one instance per input).
"""
if not hasattr(self.build, '_is_default'):
self._build_input_shape = input_shape
self.built = True
@doc_controls.for_subclass_implementers

View File

@ -19,6 +19,7 @@ py_library(
"save.py",
"saved_model/base_serialization.py",
"saved_model/constants.py",
"saved_model/json_utils.py",
"saved_model/layer_serialization.py",
"saved_model/load.py",
"saved_model/model_serialization.py",
@ -164,6 +165,7 @@ tf_py_test(
size = "medium",
srcs = ["saved_model/revive_test.py"],
python_version = "PY3",
shard_count = 4,
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras",
@ -171,3 +173,15 @@ tf_py_test(
"@absl_py//absl/testing:parameterized",
],
)
# Unit test for the custom JSON encoder/decoder used when writing Keras layer
# metadata (e.g. build_input_shape) into SavedModels; see
# saved_model/json_utils.py.
tf_py_test(
    name = "json_utils_test",
    size = "small",
    srcs = ["saved_model/json_utils_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

View File

@ -19,12 +19,10 @@ from __future__ import division
from __future__ import print_function
import abc
import json
import six
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import serialization
@six.add_metaclass(abc.ABCMeta)
@ -53,9 +51,7 @@ class SavedModelSaver(object):
"""
# TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
# object is in the python property)
return json.dumps(
self.python_properties,
default=serialization.get_json_type)
return json_utils.Encoder().encode(self.python_properties)
def list_extra_dependencies_for_serialization(self, serialization_cache):
"""Lists extra dependencies to serialize to SavedModel.

View File

@ -0,0 +1,69 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for creating and loading the Layer metadata for SavedModel.
These are required to retain the original format of the build input shape, since
layers and models may have different build behaviors depending on if the shape
is a list, tuple, or TensorShape. For example, Network.build() will create
separate inputs if the given input_shape is a list, and will create a single
input if the given shape is a tuple.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import serialization
class Encoder(json.JSONEncoder):
  """JSON encoder that can serialize `TensorShape`s and tuples.

  `TensorShape` objects are written as `{'class_name': 'TensorShape',
  'items': <dims or None>}` and tuples are wrapped by `_encode_tuple` before
  the standard JSON pass, so that `decode` can restore the original types.
  """

  def default(self, obj):
    """Serializes objects the stock JSONEncoder cannot handle."""
    if not isinstance(obj, tensor_shape.TensorShape):
      # Fall back to the generic Keras/TF JSON conversion.
      return serialization.get_json_type(obj)
    # An unknown-rank shape is recorded with items=None so decode can rebuild
    # TensorShape(None).
    dims = None if obj.rank is None else obj.as_list()
    return {'class_name': 'TensorShape', 'items': dims}

  def encode(self, obj):
    """Encodes `obj`, first tagging tuples so they survive round-tripping."""
    return super(Encoder, self).encode(_encode_tuple(obj))
def _encode_tuple(x):
if isinstance(x, tuple):
return {'class_name': '__tuple__',
'items': tuple(_encode_tuple(i) for i in x)}
elif isinstance(x, list):
return [_encode_tuple(i) for i in x]
elif isinstance(x, dict):
return {key: _encode_tuple(value) for key, value in x.items()}
else:
return x
def decode(json_string):
  """Decodes a JSON string produced by `Encoder`.

  Restores TensorShapes and tuples from their marker-dict encodings via the
  `_decode_helper` object hook; all other JSON values decode normally.
  """
  return json.loads(json_string, object_hook=_decode_helper)
def _decode_helper(obj):
if isinstance(obj, dict) and 'class_name' in obj:
if obj['class_name'] == 'TensorShape':
return tensor_shape.TensorShape(obj['items'])
elif obj['class_name'] == '__tuple__':
return tuple(_decode_helper(i) for i in obj['items'])
return obj

View File

@ -0,0 +1,55 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Tests the JSON encoder and decoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.platform import test
class JsonUtilsTest(test.TestCase):
  """Round-trip tests for the custom JSON Encoder/decoder."""

  def test_encode_decode_tensor_shape(self):
    # TensorShapes (unknown rank and partially-known dims) must survive an
    # encode -> decode round trip with their original structure.
    original = {
        'key1': tensor_shape.TensorShape(None),
        'key2': [tensor_shape.TensorShape([None]),
                 tensor_shape.TensorShape([3, None, 5])]}
    decoded = json_utils.decode(json_utils.Encoder().encode(original))
    self.assertEqual(set(decoded.keys()), {'key1', 'key2'})
    self.assertAllEqual(decoded['key1'].rank, None)
    self.assertAllEqual(decoded['key2'][0].as_list(), [None])
    self.assertAllEqual(decoded['key2'][1].as_list(), [3, None, 5])

  def test_encode_decode_tuple(self):
    # Tuples (possibly nested inside lists/tuples) must come back as tuples,
    # not lists.
    original = {
        'key1': (3, 5),
        'key2': [(1, (3, 4)), (1,)]}
    decoded = json_utils.decode(json_utils.Encoder().encode(original))
    self.assertEqual(set(decoded.keys()), {'key1', 'key2'})
    self.assertAllEqual(decoded['key1'], (3, 5))
    self.assertAllEqual(decoded['key2'], [(1, (3, 4)), (1,)])


if __name__ == '__main__':
  test.main()

View File

@ -68,6 +68,8 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
hasattr(self.obj.activity_regularizer, 'get_config')):
metadata['activity_regularizer'] = generic_utils.serialize_keras_object(
self.obj.activity_regularizer)
if self.obj._build_input_shape is not None: # pylint: disable=protected-access
metadata['build_input_shape'] = self.obj._build_input_shape # pylint: disable=protected-access
return metadata
def objects_to_serialize(self, serialization_cache):

View File

@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import re
from tensorflow.python.eager import context
@ -29,6 +28,7 @@ from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.keras.saving.saved_model import utils
from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints
from tensorflow.python.keras.utils import generic_utils
@ -40,6 +40,7 @@ from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking.tracking import delete_tracking
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.lazy_loader import LazyLoader
# To avoid circular dependencies between keras/engine and keras/saving,
@ -164,6 +165,8 @@ class KerasObjectLoader(tf_load.Loader):
# records all nodes that were generated directly/indirectly from the config,
# so that they do not get recreated multiple times.
self._nodes_recreated_from_config = {}
self._all_nodes_recreated_from_config = (
object_identity.ObjectIdentityWeakSet())
# Store all node ids that have already been traversed when tracking nodes
# that were recreated from the config.
self._traversed_nodes_from_config = []
@ -227,30 +230,36 @@ class KerasObjectLoader(tf_load.Loader):
return
self._traversed_nodes_from_config.append(node_id)
obj._maybe_initialize_trackable()
for reference in proto.children:
obj_child = obj._lookup_dependency(reference.local_name)
setter = setattr
child_id = reference.node_id
child_proto = self._proto.nodes[child_id]
if not isinstance(obj_child, trackable.Trackable):
continue
if obj_child._object_identifier in revived_types.registered_identifiers():
setter = lambda *unused: None
if (child_proto.user_object.identifier in
revived_types.registered_identifiers()):
setter = revived_types.get_setter(child_proto.user_object)
elif obj_child._object_identifier in KERAS_OBJECT_IDENTIFIERS:
metadata = self._proto.nodes[reference.node_id].user_object.metadata
setter = _revive_setter
_add_serialized_attributes(obj_child, json.loads(metadata))
else:
setter = setattr
# pylint: enable=protected-access
if (reference.node_id in self._nodes_recreated_from_config and
self._nodes_recreated_from_config[reference.node_id][0] is not
obj_child):
if (child_id in self._nodes_recreated_from_config and
self._nodes_recreated_from_config[child_id][0] is not obj_child):
# This means that the same trackable object is referenced by two
# different objects that were recreated from the config.
logging.warn('Looks like there is an object (perhaps variable or layer)'
' that is shared between different layers/models. This '
'may cause issues when training the model. Object: {}'
.format(obj_child))
self._nodes_recreated_from_config[reference.node_id] = obj_child, setter
'may cause issues when restoring the variable values.'
'Object: {}'.format(obj_child))
self._nodes_recreated_from_config[child_id] = (
obj_child, self._config_node_setter(setter))
self._all_nodes_recreated_from_config.add(obj_child)
self._add_children_recreated_from_config(
obj_child, self._proto.nodes[reference.node_id], reference.node_id)
obj_child, child_proto, child_id)
def _load_layers(self):
layers = {}
@ -262,19 +271,35 @@ class KerasObjectLoader(tf_load.Loader):
def _load_layer(self, proto, node_id):
"""Load a single layer from a SavedUserObject proto."""
metadata = json_utils.decode(proto.metadata)
# If node was already created
if node_id in self._nodes_recreated_from_config:
node, setter = self._nodes_recreated_from_config[node_id]
self._try_build_layer(node, node_id, metadata.get('build_input_shape'))
# Revive setter requires the object to have a `_serialized_attributes`
# property. Add it here.
_maybe_add_serialized_attributes(node, metadata)
config = metadata.get('config')
if _is_graph_network(node) and generic_utils.validate_config(config):
self.model_layer_dependencies[node_id] = (
node, self._get_child_layer_node_ids(node_id, node.name))
return node, setter
# Detect whether this object can be revived from the config. If not, then
# revive from the SavedModel instead.
metadata = json.loads(proto.metadata)
obj, setter = self._revive_from_config(metadata, node_id)
if obj is None:
obj, setter = revive_custom_object(proto.identifier, metadata)
if setter == _revive_setter:
# Add an attribute that stores the extra functions/objects saved in the
# SavedModel. Most of these functions/objects are ignored, but some are
# used later in the loading process (e.g. the list of regularization
# losses, or the training config of compiled models).
_add_serialized_attributes(obj, metadata)
# Add an attribute that stores the extra functions/objects saved in the
# SavedModel. Most of these functions/objects are ignored, but some are
# used later in the loading process (e.g. the list of regularization
# losses, or the training config of compiled models).
_maybe_add_serialized_attributes(obj, metadata)
return obj, setter
def _revive_from_config(self, metadata, node_id):
@ -284,8 +309,9 @@ class KerasObjectLoader(tf_load.Loader):
if obj is None:
return None, None
setter = _revive_setter
setter = self._config_node_setter(_revive_setter)
self._nodes_recreated_from_config[node_id] = obj, setter
self._all_nodes_recreated_from_config.add(obj)
self._add_children_recreated_from_config(
obj, self._proto.nodes[node_id], node_id)
return obj, setter
@ -300,9 +326,8 @@ class KerasObjectLoader(tf_load.Loader):
model_is_functional_or_sequential = (
metadata.get('is_graph_network', False) or
metadata['class_name'] == 'Sequential')
if (config is None or
generic_utils.LAYER_UNDEFINED_CONFIG_KEY in config or
not model_is_functional_or_sequential):
if not (generic_utils.validate_config(config) and
model_is_functional_or_sequential):
return None # Revive as custom model.
# Revive functional and sequential models as blank model objects for now (
@ -329,7 +354,7 @@ class KerasObjectLoader(tf_load.Loader):
# found.
class_name = metadata.get('class_name')
config = metadata.get('config')
if config is None or generic_utils.LAYER_UNDEFINED_CONFIG_KEY in config:
if not generic_utils.validate_config(config):
return None
try:
@ -348,16 +373,31 @@ class KerasObjectLoader(tf_load.Loader):
obj._set_dtype_policy(metadata['dtype'])
# pylint: enable=protected-access
input_shape = None
if not isinstance(obj, input_layer.InputLayer):
input_shape = self._infer_inputs(node_id, convert_to_shapes=True)
if input_shape is None:
return None
obj.build(input_shape)
obj.built = True
build_input_shape = metadata.get('build_input_shape')
built = self._try_build_layer(obj, node_id, build_input_shape)
if not built:
# If the layer cannot be built, revive a custom layer instead.
return None
return obj
  def _try_build_layer(self, obj, node_id, build_input_shape):
    """Attempts to build the revived layer.

    Args:
      obj: Revived layer (or model) object.
      node_id: SavedModel node id for `obj`, used to infer an input shape
        when none was recorded in the metadata.
      build_input_shape: The `build_input_shape` saved in the SavedModel
        metadata, or None if it was not recorded.

    Returns:
      True if the layer is built after this call, False otherwise.
    """
    if obj.built or hasattr(obj.build, '_is_default'):
      # Either already built, or the layer has the default no-op build, in
      # which case there is nothing to create -- just mark it built.
      obj.built = True
      return True
    if build_input_shape is None:
      # No shape in the metadata; fall back to inferring one from the saved
      # call functions. NOTE(review): presumably returns None when inference
      # is not possible -- confirm against _infer_inputs.
      build_input_shape = self._infer_inputs(node_id, convert_to_shapes=True)
    if build_input_shape is not None:
      obj.build(build_input_shape)
      # Also run the base Layer.build, which records _build_input_shape and
      # sets `built` (a subclassed build may not call super().build).
      base_layer.Layer.build(obj, build_input_shape)
      return True
    return False
def _load_edges(self):
"""Add edges for all nodes that are not waiting on initialization."""
for node_id, proto in enumerate(self._proto.nodes):
@ -432,8 +472,8 @@ class KerasObjectLoader(tf_load.Loader):
.format(uninitialized_model_names))
def _reconstruct_model(self, model_id, model, layers):
config = (
json.loads(self._proto.nodes[model_id].user_object.metadata)['config'])
config = json_utils.decode(
self._proto.nodes[model_id].user_object.metadata)['config']
if isinstance(model, models_lib.Sequential):
if not isinstance(layers[0], input_layer.InputLayer):
if 'batch_input_shape' in config['layers'][0]['config']:
@ -502,6 +542,14 @@ class KerasObjectLoader(tf_load.Loader):
else:
return inputs
  def _config_node_setter(self, setter):
    """Creates edges for nodes that are recreated from config.

    Wraps `setter` so that attributes already created as dependencies when
    the object was revived from its config are never overwritten by values
    loaded from the SavedModel.
    """
    def setattr_wrapper(obj, name, value):
      # Avoid overwriting attributes of objects recreated from the config.
      if obj._lookup_dependency(name) is None:  # pylint: disable=protected-access
        setter(obj, name, value)
    return setattr_wrapper
def _finalize_saved_model_layers(layers):
"""Runs the final steps of loading Keras Layers from SavedModel."""
@ -626,8 +674,9 @@ class RevivedLayer(object):
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
# pylint:disable=protected-access
revived_obj._expects_training_arg = metadata['expects_training_arg']
if metadata.get('config') is not None:
revived_obj._config = metadata['config']
config = metadata.get('config')
if generic_utils.validate_config(config):
revived_obj._config = config
if metadata.get('input_spec') is not None:
revived_obj.input_spec = recursively_deserialize_keras_object(
metadata['input_spec'],
@ -747,8 +796,9 @@ class RevivedNetwork(RevivedLayer):
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
# pylint:disable=protected-access
revived_obj._expects_training_arg = metadata['expects_training_arg']
if metadata.get('config') is not None:
revived_obj._config = metadata['config']
config = metadata.get('config')
if generic_utils.validate_config(config):
revived_obj._config = config
if metadata.get('activity_regularizer') is not None:
revived_obj.activity_regularizer = regularizers.deserialize(
@ -769,12 +819,13 @@ def _set_network_attributes_from_metadata(revived_obj):
# pylint:enable=protected-access
def _add_serialized_attributes(layer, metadata):
def _maybe_add_serialized_attributes(layer, metadata):
  """Attaches the SavedModel metadata to `layer` if not already present."""
  # Store attributes revived from SerializedAttributes in a un-tracked
  # dictionary. The attributes are the ones listed in CommonEndpoints or
  # "keras_api" for keras-specific attributes.
  if not hasattr(layer, '_serialized_attributes'):
    # The scope prevents the dict from being tracked as a checkpoint
    # dependency of the layer.
    with trackable.no_automatic_dependency_tracking_scope(layer):
      layer._serialized_attributes = {'metadata': metadata}  # pylint: disable=protected-access
def _get_keras_attr(layer):

View File

@ -50,15 +50,18 @@ class SubclassedModelNoConfig(keras.Model):
self.a = a
self.b = b
self.shared = CustomLayerNoConfig(a, b)
self.all_layers = [
self.all_layers = []
def build(self, input_shape):
self.all_layers.extend([
self.shared,
CustomLayerWithConfig(a + 1, b + 2),
CustomLayerNoConfig(a + 3, b + 4),
CustomLayerWithConfig(self.a + 1, self.b + 2),
CustomLayerNoConfig(self.a + 3, self.b + 4),
keras.Sequential([
# TODO(b/145029112): Bug with losses when there are shared layers.
# self.shared, <-- Enable when bug is fixed.
CustomLayerNoConfig(a + 5, b + 6)
])]
CustomLayerNoConfig(self.a + 5, self.b + 6)])])
super(SubclassedModelNoConfig, self).build(input_shape)
def call(self, inputs):
x = inputs

View File

@ -44,7 +44,7 @@ _GLOBAL_CUSTOM_NAMES = {}
_SKIP_FAILED_SERIALIZATION = False
# If a layer does not have a defined config, then the returned config will be a
# dictionary with the below key.
LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'
_LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'
@keras_export('keras.utils.CustomObjectScope')
@ -271,7 +271,7 @@ def serialize_keras_object(instance):
except NotImplementedError as e:
if _SKIP_FAILED_SERIALIZATION:
return serialize_keras_class_and_config(
name, {LAYER_UNDEFINED_CONFIG_KEY: True})
name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
raise e
serialization_config = {}
for key, item in config.items():
@ -790,3 +790,8 @@ def validate_kwargs(kwargs,
for kwarg in kwargs:
if kwarg not in allowed_kwargs:
raise TypeError(error_message, kwarg)
def validate_config(config):
  """Determines whether config appears to be a valid layer config.

  A valid config is a dict that was not produced by the failed-serialization
  fallback (i.e. it does not carry the undefined-config marker key).
  """
  if not isinstance(config, dict):
    return False
  return _LAYER_UNDEFINED_CONFIG_KEY not in config

View File

@ -169,3 +169,13 @@ def deserialize(proto):
def registered_identifiers():
  """Returns a view of all identifiers with a registered revived type."""
  return _REVIVED_TYPE_REGISTRY.keys()
def get_setter(proto):
  """Returns the setter registered for `proto`'s identifier, or None.

  Looks up the revived-type registrations for `proto.identifier` and returns
  the setter of the first registration that reports it should load `proto`.
  """
  _, type_registrations = _REVIVED_TYPE_REGISTRY.get(
      proto.identifier, (None, None))
  if type_registrations is None:
    return None
  for registration in type_registrations:
    if registration.should_load(proto):
      return registration.setter
  return None