Move tf.Keras object identifiers into constants.py.
PiperOrigin-RevId: 345257730 Change-Id: I0961d32f28300fe06cac6fa16928fec119baf779
This commit is contained in:
parent
753786571e
commit
a39726f0f4
tensorflow/python/keras
@ -22,9 +22,9 @@ message SavedObject {
|
||||
string node_path = 3;
|
||||
|
||||
// Identifier to determine loading function.
|
||||
// Currently supported identifiers:
|
||||
// _tf_keras_layer, _tf_keras_input_layer, _tf_keras_rnn_layer,
|
||||
// _tf_keras_metric, _tf_keras_network, _tf_keras_model,
|
||||
// Must be one of:
|
||||
// _tf_keras_input_layer, _tf_keras_layer, _tf_keras_metric,
|
||||
// _tf_keras_model, _tf_keras_network, _tf_keras_rnn_layer,
|
||||
// _tf_keras_sequential
|
||||
string identifier = 4;
|
||||
// Metadata containing a JSON-serialized object with the non-TensorFlow
|
||||
|
@ -30,3 +30,22 @@ KERAS_CACHE_KEY = 'keras_serialized_attributes'
|
||||
|
||||
# Name of Keras metadata file stored in the SavedModel.
|
||||
SAVED_METADATA_PATH = 'keras_metadata.pb'
|
||||
|
||||
# Names of SavedObject Keras identifiers.
|
||||
INPUT_LAYER_IDENTIFIER = '_tf_keras_input_layer'
|
||||
LAYER_IDENTIFIER = '_tf_keras_layer'
|
||||
METRIC_IDENTIFIER = '_tf_keras_metric'
|
||||
MODEL_IDENTIFIER = '_tf_keras_model'
|
||||
NETWORK_IDENTIFIER = '_tf_keras_network'
|
||||
RNN_LAYER_IDENTIFIER = '_tf_keras_rnn_layer'
|
||||
SEQUENTIAL_IDENTIFIER = '_tf_keras_sequential'
|
||||
|
||||
KERAS_OBJECT_IDENTIFIERS = (
|
||||
INPUT_LAYER_IDENTIFIER,
|
||||
LAYER_IDENTIFIER,
|
||||
METRIC_IDENTIFIER,
|
||||
MODEL_IDENTIFIER,
|
||||
NETWORK_IDENTIFIER,
|
||||
RNN_LAYER_IDENTIFIER,
|
||||
SEQUENTIAL_IDENTIFIER,
|
||||
)
|
||||
|
@ -33,7 +33,7 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
|
||||
|
||||
@property
|
||||
def object_identifier(self):
|
||||
return '_tf_keras_layer'
|
||||
return constants.LAYER_IDENTIFIER
|
||||
|
||||
@property
|
||||
def python_properties(self):
|
||||
@ -127,7 +127,7 @@ class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
|
||||
|
||||
@property
|
||||
def object_identifier(self):
|
||||
return '_tf_keras_input_layer'
|
||||
return constants.INPUT_LAYER_IDENTIFIER
|
||||
|
||||
@property
|
||||
def python_properties(self):
|
||||
@ -153,7 +153,7 @@ class RNNSavedModelSaver(LayerSavedModelSaver):
|
||||
|
||||
@property
|
||||
def object_identifier(self):
|
||||
return '_tf_keras_rnn_layer'
|
||||
return constants.RNN_LAYER_IDENTIFIER
|
||||
|
||||
def _get_serialized_attributes_internal(self, serialization_cache):
|
||||
objects, functions = (
|
||||
|
@ -92,12 +92,6 @@ PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
|
||||
PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR)
|
||||
|
||||
|
||||
KERAS_OBJECT_IDENTIFIERS = (
|
||||
'_tf_keras_layer', '_tf_keras_input_layer', '_tf_keras_network',
|
||||
'_tf_keras_model', '_tf_keras_sequential', '_tf_keras_metric',
|
||||
'_tf_keras_rnn_layer')
|
||||
|
||||
|
||||
def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
|
||||
"""Loads Keras objects from a SavedModel.
|
||||
|
||||
@ -196,7 +190,7 @@ def _read_legacy_metadata(object_graph_def, metadata):
|
||||
node_paths = _generate_object_paths(object_graph_def)
|
||||
for node_id, proto in enumerate(object_graph_def.nodes):
|
||||
if (proto.WhichOneof('kind') == 'user_object' and
|
||||
proto.user_object.identifier in KERAS_OBJECT_IDENTIFIERS):
|
||||
proto.user_object.identifier in constants.KERAS_OBJECT_IDENTIFIERS):
|
||||
metadata.nodes.add(
|
||||
node_id=node_id,
|
||||
node_path=node_paths[node_id],
|
||||
@ -347,7 +341,7 @@ class KerasObjectLoader(object):
|
||||
if (child_proto.user_object.identifier in
|
||||
revived_types.registered_identifiers()):
|
||||
setter = revived_types.get_setter(child_proto.user_object)
|
||||
elif obj_child._object_identifier in KERAS_OBJECT_IDENTIFIERS:
|
||||
elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
|
||||
setter = _revive_setter
|
||||
else:
|
||||
setter = setattr
|
||||
@ -384,7 +378,7 @@ class KerasObjectLoader(object):
|
||||
# time by creating objects multiple times).
|
||||
metric_list = []
|
||||
for node_metadata in self._metadata.nodes:
|
||||
if node_metadata.identifier == '_tf_keras_metric':
|
||||
if node_metadata.identifier == constants.METRIC_IDENTIFIER:
|
||||
metric_list.append(node_metadata)
|
||||
continue
|
||||
|
||||
@ -432,7 +426,7 @@ class KerasObjectLoader(object):
|
||||
|
||||
def _revive_from_config(self, identifier, metadata, node_id):
|
||||
"""Revives a layer/model from config, or returns None."""
|
||||
if identifier == '_tf_keras_metric':
|
||||
if identifier == constants.METRIC_IDENTIFIER:
|
||||
obj = self._revive_metric_from_config(metadata)
|
||||
else:
|
||||
obj = (
|
||||
@ -921,11 +915,12 @@ def revive_custom_object(identifier, metadata):
|
||||
model_class = training_lib_v1.Model
|
||||
|
||||
revived_classes = {
|
||||
'_tf_keras_layer': (RevivedLayer, base_layer.Layer),
|
||||
'_tf_keras_input_layer': (RevivedInputLayer, input_layer.InputLayer),
|
||||
'_tf_keras_network': (RevivedNetwork, functional_lib.Functional),
|
||||
'_tf_keras_model': (RevivedNetwork, model_class),
|
||||
'_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential),
|
||||
constants.INPUT_LAYER_IDENTIFIER: (
|
||||
RevivedInputLayer, input_layer.InputLayer),
|
||||
constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer),
|
||||
constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class),
|
||||
constants.NETWORK_IDENTIFIER: (RevivedNetwork, functional_lib.Functional),
|
||||
constants.SEQUENTIAL_IDENTIFIER: (RevivedNetwork, models_lib.Sequential),
|
||||
}
|
||||
parent_classes = revived_classes.get(identifier, None)
|
||||
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.keras.saving.saved_model import constants
|
||||
from tensorflow.python.keras.saving.saved_model import layer_serialization
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
@ -28,7 +29,7 @@ class MetricSavedModelSaver(layer_serialization.LayerSavedModelSaver):
|
||||
|
||||
@property
|
||||
def object_identifier(self):
|
||||
return '_tf_keras_metric'
|
||||
return constants.METRIC_IDENTIFIER
|
||||
|
||||
def _python_properties_internal(self):
|
||||
metadata = dict(
|
||||
|
@ -29,7 +29,7 @@ class ModelSavedModelSaver(layer_serialization.LayerSavedModelSaver):
|
||||
|
||||
@property
|
||||
def object_identifier(self):
|
||||
return '_tf_keras_model'
|
||||
return constants.MODEL_IDENTIFIER
|
||||
|
||||
def _python_properties_internal(self):
|
||||
metadata = super(ModelSavedModelSaver, self)._python_properties_internal()
|
||||
@ -63,4 +63,4 @@ class SequentialSavedModelSaver(ModelSavedModelSaver):
|
||||
|
||||
@property
|
||||
def object_identifier(self):
|
||||
return '_tf_keras_sequential'
|
||||
return constants.SEQUENTIAL_IDENTIFIER
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.keras.saving.saved_model import constants
|
||||
from tensorflow.python.keras.saving.saved_model import model_serialization
|
||||
|
||||
|
||||
@ -27,4 +28,4 @@ class NetworkSavedModelSaver(model_serialization.ModelSavedModelSaver):
|
||||
|
||||
@property
|
||||
def object_identifier(self):
|
||||
return '_tf_keras_network'
|
||||
return constants.NETWORK_IDENTIFIER
|
||||
|
Loading…
Reference in New Issue
Block a user