Move tf.Keras object identifiers into constants.py.

PiperOrigin-RevId: 345257730
Change-Id: I0961d32f28300fe06cac6fa16928fec119baf779
This commit is contained in:
Monica Song 2020-12-02 10:12:44 -08:00 committed by TensorFlower Gardener
parent 753786571e
commit a39726f0f4
7 changed files with 41 additions and 25 deletions

View File

@ -22,9 +22,9 @@ message SavedObject {
string node_path = 3;
// Identifier to determine loading function.
// Currently supported identifiers:
// _tf_keras_layer, _tf_keras_input_layer, _tf_keras_rnn_layer,
// _tf_keras_metric, _tf_keras_network, _tf_keras_model,
// Must be one of:
// _tf_keras_input_layer, _tf_keras_layer, _tf_keras_metric,
// _tf_keras_model, _tf_keras_network, _tf_keras_rnn_layer,
// _tf_keras_sequential
string identifier = 4;
// Metadata containing a JSON-serialized object with the non-TensorFlow

View File

@ -30,3 +30,22 @@ KERAS_CACHE_KEY = 'keras_serialized_attributes'
# Name of Keras metadata file stored in the SavedModel.
SAVED_METADATA_PATH = 'keras_metadata.pb'
# Names of SavedObject Keras identifiers.
INPUT_LAYER_IDENTIFIER = '_tf_keras_input_layer'
LAYER_IDENTIFIER = '_tf_keras_layer'
METRIC_IDENTIFIER = '_tf_keras_metric'
MODEL_IDENTIFIER = '_tf_keras_model'
NETWORK_IDENTIFIER = '_tf_keras_network'
RNN_LAYER_IDENTIFIER = '_tf_keras_rnn_layer'
SEQUENTIAL_IDENTIFIER = '_tf_keras_sequential'
KERAS_OBJECT_IDENTIFIERS = (
INPUT_LAYER_IDENTIFIER,
LAYER_IDENTIFIER,
METRIC_IDENTIFIER,
MODEL_IDENTIFIER,
NETWORK_IDENTIFIER,
RNN_LAYER_IDENTIFIER,
SEQUENTIAL_IDENTIFIER,
)

View File

@ -33,7 +33,7 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
@property
def object_identifier(self):
return '_tf_keras_layer'
return constants.LAYER_IDENTIFIER
@property
def python_properties(self):
@ -127,7 +127,7 @@ class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
@property
def object_identifier(self):
return '_tf_keras_input_layer'
return constants.INPUT_LAYER_IDENTIFIER
@property
def python_properties(self):
@ -153,7 +153,7 @@ class RNNSavedModelSaver(LayerSavedModelSaver):
@property
def object_identifier(self):
return '_tf_keras_rnn_layer'
return constants.RNN_LAYER_IDENTIFIER
def _get_serialized_attributes_internal(self, serialization_cache):
objects, functions = (

View File

@ -92,12 +92,6 @@ PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR)
KERAS_OBJECT_IDENTIFIERS = (
'_tf_keras_layer', '_tf_keras_input_layer', '_tf_keras_network',
'_tf_keras_model', '_tf_keras_sequential', '_tf_keras_metric',
'_tf_keras_rnn_layer')
def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
"""Loads Keras objects from a SavedModel.
@ -196,7 +190,7 @@ def _read_legacy_metadata(object_graph_def, metadata):
node_paths = _generate_object_paths(object_graph_def)
for node_id, proto in enumerate(object_graph_def.nodes):
if (proto.WhichOneof('kind') == 'user_object' and
proto.user_object.identifier in KERAS_OBJECT_IDENTIFIERS):
proto.user_object.identifier in constants.KERAS_OBJECT_IDENTIFIERS):
metadata.nodes.add(
node_id=node_id,
node_path=node_paths[node_id],
@ -347,7 +341,7 @@ class KerasObjectLoader(object):
if (child_proto.user_object.identifier in
revived_types.registered_identifiers()):
setter = revived_types.get_setter(child_proto.user_object)
elif obj_child._object_identifier in KERAS_OBJECT_IDENTIFIERS:
elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
setter = _revive_setter
else:
setter = setattr
@ -384,7 +378,7 @@ class KerasObjectLoader(object):
# time by creating objects multiple times).
metric_list = []
for node_metadata in self._metadata.nodes:
if node_metadata.identifier == '_tf_keras_metric':
if node_metadata.identifier == constants.METRIC_IDENTIFIER:
metric_list.append(node_metadata)
continue
@ -432,7 +426,7 @@ class KerasObjectLoader(object):
def _revive_from_config(self, identifier, metadata, node_id):
"""Revives a layer/model from config, or returns None."""
if identifier == '_tf_keras_metric':
if identifier == constants.METRIC_IDENTIFIER:
obj = self._revive_metric_from_config(metadata)
else:
obj = (
@ -921,11 +915,12 @@ def revive_custom_object(identifier, metadata):
model_class = training_lib_v1.Model
revived_classes = {
'_tf_keras_layer': (RevivedLayer, base_layer.Layer),
'_tf_keras_input_layer': (RevivedInputLayer, input_layer.InputLayer),
'_tf_keras_network': (RevivedNetwork, functional_lib.Functional),
'_tf_keras_model': (RevivedNetwork, model_class),
'_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential),
constants.INPUT_LAYER_IDENTIFIER: (
RevivedInputLayer, input_layer.InputLayer),
constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer),
constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class),
constants.NETWORK_IDENTIFIER: (RevivedNetwork, functional_lib.Functional),
constants.SEQUENTIAL_IDENTIFIER: (RevivedNetwork, models_lib.Sequential),
}
parent_classes = revived_classes.get(identifier, None)

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import layer_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.training.tracking import data_structures
@ -28,7 +29,7 @@ class MetricSavedModelSaver(layer_serialization.LayerSavedModelSaver):
@property
def object_identifier(self):
return '_tf_keras_metric'
return constants.METRIC_IDENTIFIER
def _python_properties_internal(self):
metadata = dict(

View File

@ -29,7 +29,7 @@ class ModelSavedModelSaver(layer_serialization.LayerSavedModelSaver):
@property
def object_identifier(self):
return '_tf_keras_model'
return constants.MODEL_IDENTIFIER
def _python_properties_internal(self):
metadata = super(ModelSavedModelSaver, self)._python_properties_internal()
@ -63,4 +63,4 @@ class SequentialSavedModelSaver(ModelSavedModelSaver):
@property
def object_identifier(self):
return '_tf_keras_sequential'
return constants.SEQUENTIAL_IDENTIFIER

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import model_serialization
@ -27,4 +28,4 @@ class NetworkSavedModelSaver(model_serialization.ModelSavedModelSaver):
@property
def object_identifier(self):
return '_tf_keras_network'
return constants.NETWORK_IDENTIFIER